Merge pull request #592 from corda/aslemmer-rpc-manual-demux

RPC muxing, multithreading, RPC driver, performance tests
This commit is contained in:
Andras Slemmer 2017-05-08 11:18:21 +01:00 committed by GitHub
commit 489661a289
66 changed files with 3281 additions and 1424 deletions

View File

@ -1,3 +1,4 @@
buildscript {
// For sharing constants between builds
Properties constants = new Properties()

View File

@ -3,6 +3,7 @@ package net.corda.client.jfx.model
import com.google.common.net.HostAndPort
import javafx.beans.property.SimpleObjectProperty
import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCClientConfiguration
import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineInfo
@ -52,11 +53,14 @@ class NodeMonitorModel {
* TODO provide an unsubscribe mechanism
*/
fun register(nodeHostAndPort: HostAndPort, username: String, password: String) {
val client = CordaRPCClient(nodeHostAndPort) {
maxRetryInterval = 10.seconds.toMillis()
}
client.start(username, password)
val proxy = client.proxy()
val client = CordaRPCClient(
hostAndPort = nodeHostAndPort,
configuration = CordaRPCClientConfiguration.default.copy(
connectionMaxRetryInterval = 10.seconds
)
)
val connection = client.start(username, password)
val proxy = connection.proxy
val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates()
// Extract the flow tracking stream

View File

@ -144,6 +144,23 @@ fun Generator.Companion.doubleRange(from: Double, to: Double): Generator<Double>
from + it.nextDouble() * (to - from)
}
fun Generator.Companion.char() = Generator {
val codePoint = Math.abs(it.nextInt()) % (17 * (1 shl 16))
if (Character.isValidCodePoint(codePoint)) {
return@Generator ErrorOr(codePoint.toChar())
} else {
ErrorOr.of(IllegalStateException("Could not generate valid codepoint"))
}
}
fun Generator.Companion.string(meanSize: Double = 16.0) = replicatePoisson(meanSize, char()).map {
val builder = StringBuilder()
it.forEach {
builder.append(it)
}
builder.toString()
}
fun <A> Generator.Companion.replicate(number: Int, generator: Generator<A>): Generator<List<A>> {
val generators = mutableListOf<Generator<A>>()
for (i in 1..number) {

View File

@ -36,6 +36,7 @@ dependencies {
testCompile "org.assertj:assertj-core:${assertj_version}"
testCompile project(':test-utils')
testCompile project(':client:mock')
// Integration test helpers
integrationTestCompile "junit:junit:$junit_version"

View File

@ -1,15 +1,10 @@
package net.corda.client.rpc
import net.corda.core.contracts.DOLLARS
import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInitiator
import net.corda.core.getOrThrow
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.messaging.startFlow
import net.corda.core.messaging.startTrackedFlow
import net.corda.core.messaging.*
import net.corda.core.node.services.ServiceInfo
import net.corda.core.random63BitValue
import net.corda.core.serialization.OpaqueBytes
@ -27,7 +22,9 @@ import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.*
import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class CordaRPCClientTest : NodeBasedTest() {
private val rpcUser = User("user1", "test", permissions = setOf(
@ -36,6 +33,11 @@ class CordaRPCClientTest : NodeBasedTest() {
))
private lateinit var node: Node
private lateinit var client: CordaRPCClient
private var connection: CordaRPCConnection? = null
private fun login(username: String, password: String) {
connection = client.start(username, password)
}
@Before
fun setUp() {
@ -45,33 +47,35 @@ class CordaRPCClientTest : NodeBasedTest() {
@After
fun done() {
client.close()
connection?.close()
}
@Test
fun `log in with valid username and password`() {
client.start(rpcUser.username, rpcUser.password)
login(rpcUser.username, rpcUser.password)
}
@Test
fun `log in with unknown user`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(random63BitValue().toString(), rpcUser.password)
login(random63BitValue().toString(), rpcUser.password)
}
}
@Test
fun `log in with incorrect password`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(rpcUser.username, random63BitValue().toString())
login(rpcUser.username, random63BitValue().toString())
}
}
@Test
fun `close-send deadlock and premature shutdown on empty observable`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
println("Starting client")
login(rpcUser.username, rpcUser.password)
println("Creating proxy")
println("Starting flow")
val flowHandle = proxy.startTrackedFlow(
val flowHandle = connection!!.proxy.startTrackedFlow(
::CashIssueFlow,
20.DOLLARS, OpaqueBytes.of(0), node.info.legalIdentity, node.info.legalIdentity)
println("Started flow, waiting on result")
@ -83,9 +87,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `FlowException thrown by flow`() {
client.start(rpcUser.username, rpcUser.password)
val proxy = client.proxy()
val handle = proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity)
login(rpcUser.username, rpcUser.password)
val handle = connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity)
// TODO Restrict this to CashException once RPC serialisation has been fixed
assertThatExceptionOfType(FlowException::class.java).isThrownBy {
handle.returnValue.getOrThrow()
@ -94,9 +97,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `check basic flow has no progress`() {
client.start(rpcUser.username, rpcUser.password)
val proxy = client.proxy()
proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use {
login(rpcUser.username, rpcUser.password)
connection!!.proxy.startFlow(::CashPaymentFlow, 100.DOLLARS, node.info.legalIdentity).use {
assertFalse(it is FlowProgressHandle<*>)
assertTrue(it is FlowHandle<*>)
}
@ -104,7 +106,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `get cash balances`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
login(rpcUser.username, rpcUser.password)
val proxy = connection!!.proxy
val startCash = proxy.getCashBalances()
assertTrue(startCash.isEmpty(), "Should not start with any cash")
@ -123,7 +126,8 @@ class CordaRPCClientTest : NodeBasedTest() {
@Test
fun `flow initiator via RPC`() {
val proxy = createRpcProxy(rpcUser.username, rpcUser.password)
login(rpcUser.username, rpcUser.password)
val proxy = connection!!.proxy
val smUpdates = proxy.stateMachinesAndUpdates()
var countRpcFlows = 0
var countShellFlows = 0
@ -148,11 +152,4 @@ class CordaRPCClientTest : NodeBasedTest() {
assertEquals(2, countRpcFlows)
assertEquals(1, countShellFlows)
}
private fun createRpcProxy(username: String, password: String): CordaRPCOps {
println("Starting client")
client.start(username, password)
println("Creating proxy")
return client.proxy()
}
}

View File

@ -0,0 +1,169 @@
package net.corda.client.rpc
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.util.concurrent.Futures
import net.corda.core.messaging.RPCOps
import net.corda.core.millis
import net.corda.core.random63BitValue
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCKryo
import net.corda.testing.*
import org.apache.activemq.artemis.api.core.SimpleString
import org.junit.Test
import rx.Observable
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
class RPCStabilityTests {
interface LeakObservableOps: RPCOps {
fun leakObservable(): Observable<Nothing>
}
@Test
fun `client cleans up leaked observables`() {
rpcDriver {
val leakObservableOpsImpl = object : LeakObservableOps {
val leakedUnsubscribedCount = AtomicInteger(0)
override val protocolVersion = 0
override fun leakObservable(): Observable<Nothing> {
return PublishSubject.create<Nothing>().doOnUnsubscribe {
leakedUnsubscribedCount.incrementAndGet()
}
}
}
val server = startRpcServer<LeakObservableOps>(ops = leakObservableOpsImpl)
val proxy = startRpcClient<LeakObservableOps>(server.get().hostAndPort).get()
// Leak many observables
val N = 200
(1..N).toList().parallelStream().forEach {
proxy.leakObservable()
}
// In a loop force GC and check whether the server is notified
while (true) {
System.gc()
if (leakObservableOpsImpl.leakedUnsubscribedCount.get() == N) break
Thread.sleep(100)
}
}
}
interface TrackSubscriberOps : RPCOps {
fun subscribe(): Observable<Unit>
}
/**
* In this test we create a number of out of process RPC clients that call [TrackSubscriberOps.subscribe] in a loop.
*/
@Test
fun `server cleans up queues after disconnected clients`() {
rpcDriver {
val trackSubscriberOpsImpl = object : TrackSubscriberOps {
override val protocolVersion = 0
val subscriberCount = AtomicInteger(0)
val trackSubscriberCountObservable = UnicastSubject.create<Unit>().share().
doOnSubscribe { subscriberCount.incrementAndGet() }.
doOnUnsubscribe { subscriberCount.decrementAndGet() }
override fun subscribe(): Observable<Unit> {
return trackSubscriberCountObservable
}
}
val server = startRpcServer<TrackSubscriberOps>(
configuration = RPCServerConfiguration.default.copy(
reapInterval = 100.millis
),
ops = trackSubscriberOpsImpl
).get()
val numberOfClients = 4
val clients = Futures.allAsList((1 .. numberOfClients).map {
startRandomRpcClient<TrackSubscriberOps>(server.hostAndPort)
}).get()
// Poll until all clients connect
pollUntilClientNumber(server, numberOfClients)
pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() >= 100 }.get()
// Kill one client
clients[0].destroyForcibly()
pollUntilClientNumber(server, numberOfClients - 1)
// Kill the rest
(1 .. numberOfClients - 1).forEach {
clients[it].destroyForcibly()
}
pollUntilClientNumber(server, 0)
// Now poll until the server detects the disconnects and unsubscribes from all obserables.
pollUntilTrue("number of times subscribe() has been called") { trackSubscriberOpsImpl.subscriberCount.get() == 0 }.get()
}
}
interface SlowConsumerRPCOps : RPCOps {
fun streamAtInterval(interval: Duration, size: Int): Observable<ByteArray>
}
class SlowConsumerRPCOpsImpl : SlowConsumerRPCOps {
override val protocolVersion = 0
override fun streamAtInterval(interval: Duration, size: Int): Observable<ByteArray> {
val chunk = ByteArray(size)
return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk }
}
}
val dummyObservableSerialiser = object : Serializer<Observable<Any>>() {
override fun write(kryo: Kryo?, output: Output?, `object`: Observable<Any>?) {
}
override fun read(kryo: Kryo?, input: Input?, type: Class<Observable<Any>>?): Observable<Any> {
return Observable.empty()
}
}
@Test
fun `slow consumers are kicked`() {
val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build()
rpcDriver {
val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get()
// Construct an RPC session manually so that we can hang in the message handler
val myQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.test.${random63BitValue()}"
val session = startArtemisSession(server.hostAndPort)
session.createTemporaryQueue(myQueue, myQueue)
val consumer = session.createConsumer(myQueue, null, -1, -1, false)
consumer.setMessageHandler {
Thread.sleep(50) // 5x slower than the server producer
it.acknowledge()
}
val producer = session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME)
session.start()
pollUntilClientNumber(server, 1)
val message = session.createMessage(false)
val request = RPCApi.ClientToServer.RpcRequest(
clientAddress = SimpleString(myQueue),
id = RPCApi.RpcRequestId(random63BitValue()),
methodName = SlowConsumerRPCOps::streamAtInterval.name,
arguments = listOf(10.millis, 123456)
)
request.writeToClientMessage(kryoPool, message)
producer.send(message)
session.commit()
// We are consuming slower than the server is producing, so we should be kicked after a while
pollUntilClientNumber(server, 0)
}
}
}
fun RPCDriverExposedDSLInterface.pollUntilClientNumber(server: RpcServerHandle, expected: Int) {
pollUntilTrue("number of RPC clients to become $expected") {
val clientAddresses = server.serverControl.addressNames.filter { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) }
clientAddresses.size == expected
}.get()
}

View File

@ -1,168 +1,49 @@
package net.corda.client.rpc
import com.google.common.net.HostAndPort
import net.corda.core.ThreadBox
import net.corda.core.logElapsedTime
import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.minutes
import net.corda.core.seconds
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.RPCException
import net.corda.nodeapi.config.SSLConfiguration
import net.corda.nodeapi.rpcLog
import org.apache.activemq.artemis.api.core.ActiveMQException
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ClientSessionFactory
import org.apache.activemq.artemis.api.core.client.ServerLocator
import rx.Observable
import java.io.Closeable
import java.time.Duration
import javax.annotation.concurrent.ThreadSafe
/**
* An RPC client connects to the specified server and allows you to make calls to the server that perform various
* useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works.
*
* @param host The hostname and messaging port of the node.
* @param config If specified, the SSL configuration to use. If not specified, SSL will be disabled and the node will only be authenticated on non-SSL RPC port, the RPC traffic with not be encrypted when SSL is disabled.
*/
@ThreadSafe
class CordaRPCClient(val host: HostAndPort, override val config: SSLConfiguration? = null, val serviceConfigurationOverride: (ServerLocator.() -> Unit)? = null) : Closeable, ArtemisMessagingComponent() {
private companion object {
val log = loggerFor<CordaRPCClient>()
/** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */
@JvmStatic val MAX_FILE_SIZE = 10485760
class CordaRPCConnection internal constructor(
connection: RPCClient.RPCConnection<CordaRPCOps>
) : RPCClient.RPCConnection<CordaRPCOps> by connection
data class CordaRPCClientConfiguration(
val connectionMaxRetryInterval: Duration
) {
internal fun toRpcClientConfiguration(): RPCClientConfiguration {
return RPCClientConfiguration.default.copy(
connectionMaxRetryInterval = connectionMaxRetryInterval
)
}
companion object {
@JvmStatic
val default = CordaRPCClientConfiguration(
connectionMaxRetryInterval = RPCClientConfiguration.default.connectionMaxRetryInterval
)
}
}
class CordaRPCClient(
hostAndPort: HostAndPort,
sslConfiguration: SSLConfiguration? = null,
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default
) {
private val rpcClient = RPCClient<CordaRPCOps>(
tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration),
configuration.toRpcClientConfiguration()
)
fun start(username: String, password: String): CordaRPCConnection {
return CordaRPCConnection(rpcClient.start(CordaRPCOps::class.java, username, password))
}
// TODO: Certificate handling for clients needs more work.
private inner class State {
var running = false
lateinit var sessionFactory: ClientSessionFactory
lateinit var session: ClientSession
lateinit var clientImpl: CordaRPCClientImpl
}
private val state = ThreadBox(State())
/**
* Opens the connection to the server with the given username and password, then returns itself.
* Registers a JVM shutdown hook to cleanly disconnect.
*/
@Throws(ActiveMQException::class)
fun start(username: String, password: String): CordaRPCClient {
state.locked {
check(!running)
log.logElapsedTime("Startup") {
checkStorePasswords()
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport(ConnectionDirection.Outbound(), host, config, enableSSL = config != null)).apply {
// TODO: Put these in config file or make it user configurable?
threadPoolMaxSize = 1
confirmationWindowSize = 100000 // a guess
retryInterval = 5.seconds.toMillis()
retryIntervalMultiplier = 1.5 // Exponential backoff
maxRetryInterval = 3.minutes.toMillis()
minLargeMessageSize = MAX_FILE_SIZE
serviceConfigurationOverride?.invoke(this)
}
sessionFactory = serverLocator.createSessionFactory()
session = sessionFactory.createSession(username, password, false, true, true, serverLocator.isPreAcknowledge, serverLocator.ackBatchSize)
session.start()
clientImpl = CordaRPCClientImpl(session, state.lock, username)
running = true
}
}
Runtime.getRuntime().addShutdownHook(Thread {
close()
})
return this
}
/**
* A convenience function that opens a connection with the given credentials, executes the given code block with all
* available RPCs in scope and shuts down the RPC connection again. It's meant for quick prototyping and demos. For
* more control you probably want to control the lifecycle of the client and proxies independently, as well as
* configuring a timeout and other such features via the [proxy] method.
*
* After this method returns the client is closed and can't be restarted.
*/
@Throws(ActiveMQException::class)
fun <T> use(username: String, password: String, block: CordaRPCOps.() -> T): T {
require(!state.locked { running })
start(username, password)
(this as Closeable).use {
return proxy().block()
}
}
/** Shuts down the client and lets the server know it can free the used resources (in a nice way). */
override fun close() {
state.locked {
if (!running) return
session.close()
sessionFactory.close()
running = false
}
}
/**
* Returns a fresh proxy that lets you invoke RPCs on the server. Calls on it block, and if the server throws an
* exception then it will be rethrown on the client. Proxies are thread safe but only one RPC can be in flight at
* once. If you'd like to perform multiple RPCs in parallel, use this function multiple times to get multiple
* proxies.
*
* Creation of a proxy is a somewhat expensive operation that involves calls to the server, so if you want to do
* calls from many threads at once you should cache one proxy per thread and reuse them. This function itself is
* thread safe though so requires no extra synchronisation.
*
* RPC sends and receives are logged on the net.corda.rpc logger.
*
* By default there are no timeouts on calls. This is deliberate, RPCs without timeouts can survive restarts,
* maintenance downtime and moves of the server. RPCs can survive temporary losses or changes in client connectivity,
* like switching between wifi networks. You can specify a timeout on the level of a proxy. If a call times
* out it will throw [RPCException.Deadline].
*
* The [CordaRPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the
* object graph returned then the server-side observable is transparently linked to a messaging queue, and that
* queue linked to another observable on the client side here. *You are expected to use it*. The server will begin
* buffering messages immediately that it will expect you to drain by subscribing to the returned observer. You can
* opt-out of this by simply casting the [Observable] to [Closeable] or [AutoCloseable] and then calling the close
* method on it. You don't have to explicitly close the observable if you actually subscribe to it: it will close
* itself and free up the server-side resources either when the client or JVM itself is shutdown, or when there are
* no more subscribers to it. Once all the subscribers to a returned observable are unsubscribed, the observable is
* closed and you can't then re-subscribe again: you'll have to re-request a fresh observable with another RPC.
*
* The proxy and linked observables consume some small amount of resources on the server. It's OK to just exit your
* process and let the server clean up, but in a long running process where you only need something for a short
* amount of time it is polite to cast the objects to [Closeable] or [AutoCloseable] and close it when you are done.
* Finalizers are in place to warn you if you lose a reference to an unclosed proxy or observable.
*
* @throws RPCException if the server version is too low or if the server isn't reachable within the given time.
*/
@JvmOverloads
@Throws(RPCException::class)
fun proxy(timeout: Duration? = null, minVersion: Int = 0): CordaRPCOps {
return state.locked {
check(running) { "Client must have been started first" }
log.logElapsedTime("Proxy build") {
clientImpl.proxyFor(CordaRPCOps::class.java, timeout, minVersion)
}
}
}
@Suppress("UNUSED")
private fun finalize() {
state.locked {
if (running) {
rpcLog.warn("A CordaMQClient is being finalised whilst still running, did you forget to call close?")
close()
}
}
inline fun <A> use(username: String, password: String, block: (CordaRPCConnection) -> A): A {
return start(username, password).use(block)
}
}

View File

@ -1,418 +0,0 @@
package net.corda.client.rpc
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.CacheBuilder
import net.corda.core.ErrorOr
import net.corda.core.bufferUntilSubscribed
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.RPCReturnsObservables
import net.corda.core.random63BitValue
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.debug
import net.corda.nodeapi.*
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import rx.Observable
import rx.subjects.PublishSubject
import java.io.Closeable
import java.lang.ref.WeakReference
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.lang.reflect.Proxy
import java.time.Duration
import java.util.*
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy
import javax.annotation.concurrent.ThreadSafe
import kotlin.concurrent.withLock
import kotlin.reflect.jvm.javaMethod
/**
* Core RPC engine implementation, to learn how to use RPC you should be looking at [CordaRPCClient].
*
* # Design notes
*
* The way RPCs are handled is fairly standard except for the handling of observables. When an RPC might return
* an [Observable] it is specially tagged. This causes the client to create a new transient queue for the
* receiving of observables and their observations with a random ID in the name. This ID is sent to the server in
* a message header. All observations are sent via this single queue.
*
* The reason for doing it this way and not the more obvious approach of one-queue-per-observable is that we want
* the queues to be *transient*, meaning their lifetime in the broker is tied to the session that created them.
* A server side observable and its associated queue is not a cost-free thing, let alone the memory and resources
* needed to actually generate the observations themselves, therefore we want to ensure these cannot leak. A
* transient queue will be deleted automatically if the client session terminates, which by default happens on
* disconnect but can also be configured to happen after a short delay (this allows clients to e.g. switch IP
* address). On the server the deletion of the observations queue triggers unsubscription from the associated
* observables, which in turn may then be garbage collected.
*
* Creating a transient queue requires a roundtrip to the broker and thus doing an RPC that could return
* observables takes two server roundtrips instead of one. That's why we require RPCs to be marked with
* [RPCReturnsObservables] as needing this special treatment instead of always doing it.
*
* If the Artemis/JMS APIs allowed us to create transient queues assigned to someone else then we could
* potentially use a different design in which the node creates new transient queues (one per observable) on the
* fly. The client would then have to watch out for this and start consuming those queues as they were created.
*
* We use one queue per RPC because we don't know ahead of time how many observables the server might return and
* often the server doesn't know either, which pushes towards a single queue design, but at the same time the
* processing of observations returned by an RPC might be striped across multiple threads and we'd like
* backpressure management to not be scoped per client process but with more granularity. So we end up with
* a compromise where the unit of backpressure management is the response to a single RPC.
*
* TODO: Backpressure isn't propagated all the way through the MQ broker at the moment.
*/
class CordaRPCClientImpl(private val session: ClientSession,
private val sessionLock: ReentrantLock,
private val username: String) {
companion object {
private val closeableCloseMethod = Closeable::close.javaMethod
private val autocloseableCloseMethod = AutoCloseable::close.javaMethod
}
/**
* Builds a proxy for the given type, which must descend from [RPCOps].
*
* @see CordaRPCClient.proxy for more information about how to use the proxies.
*/
fun <T : RPCOps> proxyFor(rpcInterface: Class<T>, timeout: Duration? = null, minVersion: Int = 0): T {
sessionLock.withLock {
if (producer == null)
producer = session.createProducer()
}
val proxyImpl = RPCProxyHandler(timeout)
@Suppress("UNCHECKED_CAST")
val proxy = Proxy.newProxyInstance(rpcInterface.classLoader, arrayOf(rpcInterface, Closeable::class.java), proxyImpl) as T
proxyImpl.serverProtocolVersion = proxy.protocolVersion
if (minVersion > proxyImpl.serverProtocolVersion)
throw RPCException("Requested minimum protocol version $minVersion is higher than the server's supported protocol version (${proxyImpl.serverProtocolVersion})")
return proxy
}
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//region RPC engine
//
// You can find docs on all this in the api doc for the proxyFor method, and in the docsite.
// Utility to quickly suck out the contents of an Artemis message. There's probably a more efficient way to
// do this.
private fun <T : Any> ClientMessage.deserialize(kryo: Kryo): T = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }.deserialize(kryo)
// We by default use a weak reference so GC can happen, otherwise they persist for the life of the client.
@GuardedBy("sessionLock")
private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build<String, QueuedObservable>()
// This is used to hold a reference counted hard reference when we know there are subscribers.
private val hardReferencesToQueuedObservables = Collections.synchronizedSet(mutableSetOf<QueuedObservable>())
private var producer: ClientProducer? = null
class ObservableDeserializer : Serializer<Observable<Any>>() {
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
val qName = kryo.context[RPCKryoQNameKey] as String
val rpcName = kryo.context[RPCKryoMethodNameKey] as String
val rpcLocation = kryo.context[RPCKryoLocationKey] as Throwable
val rpcClient = kryo.context[RPCKryoClientKey] as CordaRPCClientImpl
val handle = input.readInt(true)
val ob = rpcClient.sessionLock.withLock {
rpcClient.addressToQueuedObservables.getIfPresent(qName) ?: rpcClient.QueuedObservable(qName, rpcName, rpcLocation).apply {
rpcClient.addressToQueuedObservables.put(qName, this)
}
}
val result = ob.getForHandle(handle)
rpcLog.debug { "Deserializing and connecting a new observable for $rpcName on $qName: $result" }
return result
}
override fun write(kryo: Kryo, output: Output, `object`: Observable<Any>) {
throw UnsupportedOperationException("not implemented")
}
}
/**
* The proxy class returned to the client is auto-generated on the fly by the java.lang.reflect Proxy
* infrastructure. The JDK Proxy class writes bytecode into memory for a class that implements the requested
* interfaces and then routes all method calls to the invoke method below in a conveniently reified form.
* We can then easily take the data about the method call and turn it into an RPC. This avoids the need
* for the compile-time code generation which is so common in RPC systems.
*/
@ThreadSafe
private inner class RPCProxyHandler(private val timeout: Duration?) : InvocationHandler, Closeable {
private val proxyId = random63BitValue()
private val consumer: ClientConsumer
var serverProtocolVersion = 0
init {
val proxyAddress = constructAddress(proxyId)
consumer = sessionLock.withLock {
session.createTemporaryQueue(proxyAddress, proxyAddress)
session.createConsumer(proxyAddress)
}
}
private fun constructAddress(addressId: Long) = "${ArtemisMessagingComponent.CLIENTS_PREFIX}$username.rpc.$addressId"
@Synchronized
override fun invoke(proxy: Any, method: Method, args: Array<out Any>?): Any? {
if (isCloseInvocation(method)) {
close()
return null
}
if (method.name == "toString" && args == null)
return "Client RPC proxy"
if (consumer.isClosed)
throw RPCException("RPC Proxy is closed")
// All invoked methods on the proxy end up here.
val location = Throwable()
rpcLog.debug {
val argStr = args?.joinToString() ?: ""
"-> RPC -> ${method.name}($argStr): ${method.returnType}"
}
checkMethodVersion(method)
val msg: ClientMessage = createMessage(method)
// We could of course also check the return type of the method to see if it's Observable, but I'd
// rather haved the annotation be used consistently.
val returnsObservables = method.isAnnotationPresent(RPCReturnsObservables::class.java)
val kryo = if (returnsObservables) maybePrepareForObservables(location, method, msg) else createRPCKryoForDeserialization(this@CordaRPCClientImpl)
val next: ErrorOr<*> = try {
sendRequest(args, msg)
receiveResponse(kryo, method, timeout)
} finally {
releaseRPCKryoForDeserialization(kryo)
}
rpcLog.debug { "<- RPC <- ${method.name} = $next" }
return unwrapOrThrow(next)
}
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
private fun unwrapOrThrow(next: ErrorOr<*>): Any? {
val ex = next.error
if (ex != null) {
// Replace the stack trace because that's an implementation detail of the server that isn't so
// helpful to the user who wants to see where the error was on their side, and serialising stack
// frame objects is a bit annoying. We slice it here to avoid the invoke() machinery being exposed.
// The resulting exception looks like it was thrown from inside the called method.
(ex as java.lang.Throwable).stackTrace = java.lang.Throwable().stackTrace.let { it.sliceArray(1..it.size - 1) }
throw ex
} else {
return next.value
}
}
private fun receiveResponse(kryo: Kryo, method: Method, timeout: Duration?): ErrorOr<*> {
val artemisMessage: ClientMessage =
if (timeout == null)
consumer.receive() ?: throw ActiveMQObjectClosedException()
else
consumer.receive(timeout.toMillis()) ?: throw RPCException.DeadlineExceeded(method.name)
artemisMessage.acknowledge()
val next = artemisMessage.deserialize<ErrorOr<*>>(kryo)
return next
}
private fun sendRequest(args: Array<out Any>?, msg: ClientMessage) {
sessionLock.withLock {
val argsKryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl)
val serializedArgs = try {
(args ?: emptyArray<Any?>()).serialize(argsKryo)
} catch (e: KryoException) {
throw RPCException("Could not serialize RPC arguments", e)
} finally {
releaseRPCKryoForDeserialization(argsKryo)
}
msg.writeBodyBufferBytes(serializedArgs.bytes)
producer!!.send(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, msg)
}
}
private fun maybePrepareForObservables(location: Throwable, method: Method, msg: ClientMessage): Kryo {
// Create a temporary queue just for the emissions on any observables that are returned.
val observationsId = random63BitValue()
val observationsQueueName = constructAddress(observationsId)
session.createTemporaryQueue(observationsQueueName, observationsQueueName)
msg.putLongProperty(ClientRPCRequestMessage.OBSERVATIONS_TO, observationsId)
// And make sure that we deserialise observable handles so that they're linked to the right
// queue. Also record a bit of metadata for debugging purposes.
return createRPCKryoForDeserialization(this@CordaRPCClientImpl, observationsQueueName, method.name, location)
}
private fun createMessage(method: Method): ClientMessage {
return session.createMessage(false).apply {
putStringProperty(ClientRPCRequestMessage.METHOD_NAME, method.name)
putLongProperty(ClientRPCRequestMessage.REPLY_TO, proxyId)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
}
private fun checkMethodVersion(method: Method) {
val methodVersion = method.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
if (methodVersion > serverProtocolVersion)
throw UnsupportedOperationException("Method ${method.name} was added in RPC protocol version $methodVersion but the server is running $serverProtocolVersion")
}
private fun isCloseInvocation(method: Method) = method == closeableCloseMethod || method == autocloseableCloseMethod
override fun close() {
consumer.close()
sessionLock.withLock { session.deleteQueue(constructAddress(proxyId)) }
}
override fun toString() = "Corda RPC Proxy listening on queue ${constructAddress(proxyId)}"
}
/**
* When subscribed to, starts consuming from the given queue name and demultiplexing the observables being
* sent to it. The server queue is moved into in-memory buffers (one per attached server-side observable)
* until drained through a subscription. When the subscriptions are all gone, the server-side queue is deleted.
*/
@ThreadSafe
private inner class QueuedObservable(private val qName: String,
private val rpcName: String,
private val rpcLocation: Throwable) {
private val root = PublishSubject.create<MarshalledObservation>()
private val rootShared = root.doOnUnsubscribe { close() }.share()
// This could be made more efficient by using a specialised IntMap
// When handling this map we don't synchronise on [this], otherwise there is a race condition between close() and deliver()
private val observables = Collections.synchronizedMap(HashMap<Int, Observable<Any>>())
@GuardedBy("sessionLock")
private var consumer: ClientConsumer? = null
private val referenceCount = AtomicInteger(0)
// We have to create a weak reference, otherwise we cannot be GC'd.
init {
val weakThis = WeakReference<QueuedObservable>(this)
consumer = sessionLock.withLock { session.createConsumer(qName) }.setMessageHandler { weakThis.get()?.deliver(it) }
}
/**
* We have to reference count subscriptions to the returned [Observable]s to prevent early GC because we are
* weak referenced.
*
* Derived [Observables] (e.g. filtered etc) hold a strong reference to the original, but for example, if
* the pattern as follows is used, the original passes out of scope and the direction of reference is from the
* original to the [Observer]. We use the reference counting to allow for this pattern.
*
* val observationsSubject = PublishSubject.create<Observation>()
* originalObservable.subscribe(observationsSubject)
* return observationsSubject
*/
private fun refCountUp() {
if (referenceCount.andIncrement == 0) {
hardReferencesToQueuedObservables.add(this)
}
}
private fun refCountDown() {
if (referenceCount.decrementAndGet() == 0) {
hardReferencesToQueuedObservables.remove(this)
}
}
fun getForHandle(handle: Int): Observable<Any> {
synchronized(observables) {
return observables.getOrPut(handle) {
/**
* Note that the order of bufferUntilSubscribed() -> dematerialize() is very important here.
*
* In particular doing it the other way around may result in the following edge case:
* The RPC returns two (or more) Observables. The first Observable unsubscribes *during serialisation*,
* before the second one is hit, causing the [rootShared] to unsubscribe and consequently closing
* the underlying artemis queue, even though the second Observable was not even registered.
*
* The buffer -> dematerialize order ensures that the Observable may not unsubscribe until the caller
* subscribes, which must be after full deserialisation and registering of all top level Observables.
*
* In addition, when subscribe and unsubscribe is called on the [Observable] returned here, we
* reference count a hard reference to this [QueuedObservable] to prevent premature GC.
*/
rootShared.filter { it.forHandle == handle }.map { it.what }.bufferUntilSubscribed().dematerialize<Any>().doOnSubscribe { refCountUp() }.doOnUnsubscribe { refCountDown() }.share()
}
}
}
private fun deliver(msg: ClientMessage) {
sessionLock.withLock { msg.acknowledge() }
val kryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl, qName, rpcName, rpcLocation)
val received: MarshalledObservation = try {
msg.deserialize(kryo)
} finally {
releaseRPCKryoForDeserialization(kryo)
}
rpcLog.debug { "<- Observable [$rpcName] <- Received $received" }
synchronized(observables) {
// Force creation of the buffer if it doesn't already exist.
getForHandle(received.forHandle)
root.onNext(received)
}
}
fun close() {
sessionLock.withLock {
if (consumer != null) {
rpcLog.debug("Closing queue observable for call to $rpcName : $qName")
consumer?.close()
consumer = null
session.deleteQueue(qName)
}
}
}
@Suppress("UNUSED")
fun finalize() {
val closed = sessionLock.withLock {
if (consumer != null) {
consumer!!.close()
consumer = null
true
} else
false
}
if (closed) {
rpcLog.warn("""A hot observable returned from an RPC ($rpcName) was never subscribed to.
This wastes server-side resources because it was queueing observations for retrieval.
It is being closed now, but please adjust your code to call .notUsed() on the observable
to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning
will appear less frequently in future versions of the platform and you can ignore it
if you want to.
""".trimIndent().replace('\n', ' '), rpcLocation)
}
}
}
//endregion
}
private val rpcDesKryoPool = KryoPool.Builder { RPCKryo(CordaRPCClientImpl.ObservableDeserializer()) }.build()
fun createRPCKryoForDeserialization(rpcClient: CordaRPCClientImpl, qName: String? = null, rpcName: String? = null, rpcLocation: Throwable? = null): Kryo {
val kryo = rpcDesKryoPool.borrow()
kryo.context.put(RPCKryoClientKey, rpcClient)
kryo.context.put(RPCKryoQNameKey, qName)
kryo.context.put(RPCKryoMethodNameKey, rpcName)
kryo.context.put(RPCKryoLocationKey, rpcLocation)
return kryo
}
fun releaseRPCKryoForDeserialization(kryo: Kryo) {
rpcDesKryoPool.release(kryo)
}

View File

@ -0,0 +1,169 @@
package net.corda.client.rpc.internal
import com.google.common.net.HostAndPort
import net.corda.core.logElapsedTime
import net.corda.core.messaging.RPCOps
import net.corda.core.minutes
import net.corda.core.random63BitValue
import net.corda.core.seconds
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCException
import net.corda.nodeapi.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import java.io.Closeable
import java.lang.reflect.Proxy
import java.time.Duration
/**
* This configuration may be used to tweak the internals of the RPC client.
*/
data class RPCClientConfiguration(
/** The minimum protocol version required from the server */
val minimumServerProtocolVersion: Int,
/**
* If set to true the client will track RPC call sites. If an error occurs subsequently during the RPC or in a
* returned Observable stream the stack trace of the originating RPC will be shown as well. Note that
* constructing call stacks is a moderately expensive operation.
*/
val trackRpcCallSites: Boolean,
/**
* The interval of unused observable reaping. Leaked Observables (unused ones) are detected using weak references
* and are cleaned up in batches in this interval. If set too large it will waste server side resources for this
* duration. If set too low it wastes client side cycles.
*/
val reapInterval: Duration,
/** The number of threads to use for observations (for executing [Observable.onNext]) */
val observationExecutorPoolSize: Int,
/** The maximum number of producers to create to handle outgoing messages */
val producerPoolBound: Int,
/**
* Determines the concurrency level of the Observable Cache. This is exposed because it implicitly determines
* the limit on the number of leaked observables reaped because of garbage collection per reaping.
* See the implementation of [com.google.common.cache.LocalCache] for details.
*/
val cacheConcurrencyLevel: Int,
/** The retry interval of artemis connections in milliseconds */
val connectionRetryInterval: Duration,
/** The retry interval multiplier for exponential backoff */
val connectionRetryIntervalMultiplier: Double,
/** Maximum retry interval */
val connectionMaxRetryInterval: Duration,
/** Maximum file size */
val maxFileSize: Int
) {
companion object {
@JvmStatic
val default = RPCClientConfiguration(
minimumServerProtocolVersion = 0,
trackRpcCallSites = false,
reapInterval = 1.seconds,
observationExecutorPoolSize = 4,
producerPoolBound = 1,
cacheConcurrencyLevel = 8,
connectionRetryInterval = 5.seconds,
connectionRetryIntervalMultiplier = 1.5,
connectionMaxRetryInterval = 3.minutes,
/** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */
maxFileSize = 10485760
)
}
}
/**
* An RPC client that may be used to create connections to an RPC server.
*
* @param transport The Artemis transport to use to connect to the server.
* @param rpcConfiguration Configuration used to tweak client behaviour.
*/
class RPCClient<I : RPCOps>(
val transport: TransportConfiguration,
val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default
) {
constructor(
hostAndPort: HostAndPort,
sslConfiguration: SSLConfiguration? = null,
configuration: RPCClientConfiguration = RPCClientConfiguration.default
) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration)
companion object {
private val log = loggerFor<RPCClient<*>>()
}
/**
* Holds a proxy object implementing [I] that forwards requests to the RPC server.
*
* [Closeable.close] may be used to shut down the connection and release associated resources.
*/
interface RPCConnection<out I : RPCOps> : Closeable {
val proxy: I
/** The RPC protocol version reported by the server */
val serverProtocolVersion: Int
}
/**
* Returns an [RPCConnection] containing a proxy that lets you invoke RPCs on the server. Calls on it block, and if
* the server throws an exception then it will be rethrown on the client. Proxies are thread safe and may be used to
* invoke multiple RPCs in parallel.
*
* RPC sends and receives are logged on the net.corda.rpc logger.
*
* The [RPCOps] defines what client RPCs are available. If an RPC returns an [Observable] anywhere in the object
* graph returned then the server-side observable is transparently forwarded to the client side here.
* *You are expected to use it*. The server will begin buffering messages immediately that it will expect you to
* drain by subscribing to the returned observer. You can opt-out of this by simply calling the
* [net.corda.client.rpc.notUsed] method on it. You don't have to explicitly close the observable if you actually
* subscribe to it: it will close itself and free up the server-side resources either when the client or JVM itself
* is shutdown, or when there are no more subscribers to it. Once all the subscribers to a returned observable are
* unsubscribed or the observable completes successfully or with an error, the observable is closed and you can't
* then re-subscribe again: you'll have to re-request a fresh observable with another RPC.
*
* @param rpcOpsClass The [Class] of the RPC interface.
* @param username The username to authenticate with.
* @param password The password to authenticate with.
* @throws RPCException if the server version is too low or if the server isn't reachable within the given time.
*/
fun start(
rpcOpsClass: Class<I>,
username: String,
password: String
): RPCConnection<I> {
return log.logElapsedTime("Startup") {
val clientAddress = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.${random63BitValue()}")
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(transport).apply {
retryInterval = rpcConfiguration.connectionRetryInterval.toMillis()
retryIntervalMultiplier = rpcConfiguration.connectionRetryIntervalMultiplier
maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis()
minLargeMessageSize = rpcConfiguration.maxFileSize
}
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass)
proxyHandler.start()
@Suppress("UNCHECKED_CAST")
val ops = Proxy.newProxyInstance(rpcOpsClass.classLoader, arrayOf(rpcOpsClass), proxyHandler) as I
val serverProtocolVersion = ops.protocolVersion
if (serverProtocolVersion < rpcConfiguration.minimumServerProtocolVersion) {
throw RPCException("Requested minimum protocol version (${rpcConfiguration.minimumServerProtocolVersion}) is higher" +
" than the server's supported protocol version ($serverProtocolVersion)")
}
proxyHandler.setServerProtocolVersion(serverProtocolVersion)
log.debug("RPC connected, returning proxy")
object : RPCConnection<I> {
override val proxy = ops
override val serverProtocolVersion = serverProtocolVersion
override fun close() {
proxyHandler.close()
serverLocator.close()
}
}
}
}
}

View File

@ -0,0 +1,422 @@
package net.corda.client.rpc.internal
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import com.google.common.cache.RemovalCause
import com.google.common.cache.RemovalListener
import com.google.common.util.concurrent.SettableFuture
import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.core.ThreadBox
import net.corda.core.getOrThrow
import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue
import net.corda.core.serialization.KryoPoolWithContext
import net.corda.core.utilities.*
import net.corda.nodeapi.*
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ServerLocator
import rx.Notification
import rx.Observable
import rx.subjects.UnicastSubject
import sun.reflect.CallerSensitive
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.collections.ArrayList
import kotlin.reflect.jvm.javaMethod
/**
* This class provides a proxy implementation of an RPC interface for RPC clients. It translates API calls to lower-level
* RPC protocol messages. For this protocol see [RPCApi].
*
* When a method is called on the interface the arguments are serialised and the request is forwarded to the server. The
* server then executes the code that implements the RPC and sends a reply.
*
* An RPC reply may contain [Observable]s, which are serialised simply as unique IDs. On the client side we create a
* [UnicastSubject] for each such ID. Subsequently the server may send observations attached to this ID, which are
* forwarded to the [UnicastSubject]. Note that the observations themselves may contain further [Observable]s, which are
* handled in the same way.
*
* To do the above we take advantage of Kryo's datastructure traversal. When the client is deserialising a message from
* the server that may contain Observables it is supplied with an [ObservableContext] that exposes the map used to demux
* the observations. When an [Observable] is encountered during traversal a new [UnicastSubject] is added to the map and
* we carry on. Each observation later contains the corresponding Observable ID, and we just forward that to the
* associated [UnicastSubject].
*
* The client may signal that it no longer consumes a particular [Observable]. This may be done explicitly by
* unsubscribing from the [Observable], or if the [Observable] is garbage collected the client will eventually
* automatically signal the server. This is done using a cache that holds weak references to the [UnicastSubject]s.
* The cleanup happens in batches using a dedicated reaper, scheduled on [reaperExecutor].
*/
class RPCClientProxyHandler(
private val rpcConfiguration: RPCClientConfiguration,
private val rpcUsername: String,
private val rpcPassword: String,
private val serverLocator: ServerLocator,
private val clientAddress: SimpleString,
private val rpcOpsClass: Class<out RPCOps>
) : InvocationHandler {
private enum class State {
UNSTARTED,
SERVER_VERSION_NOT_SET,
STARTED,
FINISHED
}
private val lifeCycle = LifeCycle(State.UNSTARTED)
private companion object {
val log = loggerFor<RPCClientProxyHandler>()
// Note that this KryoPool is not yet capable of deserialising Observables, it requires Proxy-specific context
// to do that. However it may still be used for serialisation of RPC requests and related messages.
val kryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build()
// To check whether toString() is being invoked
val toStringMethod: Method = Object::toString.javaMethod!!
}
// Used for reaping
private val reaperExecutor = Executors.newScheduledThreadPool(
1,
ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").build()
)
// A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build()
private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) {
Executors.newFixedThreadPool(1, observationExecutorThreadFactory)
}
// Holds the RPC reply futures.
private val rpcReplyMap = RpcReplyMap()
// Optionally holds RPC call site stack traces to be shown on errors/warnings.
private val callSiteMap = if (rpcConfiguration.trackRpcCallSites) CallSiteMap() else null
// Holds the Observables and a reference store to keep Observables alive when subscribed to.
private val observableContext = ObservableContext(
callSiteMap = callSiteMap,
observableMap = createRpcObservableMap(),
hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>())
)
// Holds a reference to the scheduled reaper.
private lateinit var reaperScheduledFuture: ScheduledFuture<*>
// The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion]
private var serverProtocolVersion: Int? = null
// Stores the Observable IDs that are already removed from the map but are not yet sent to the server.
private val observablesToReap = ThreadBox(object {
var observables = ArrayList<RPCApi.ObservableId>()
})
// A Kryo pool that automatically adds the observable context when an instance is requested.
private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext)
private fun createRpcObservableMap(): RpcObservableMap {
val onObservableRemove = RemovalListener<RPCApi.ObservableId, UnicastSubject<Notification<Any>>> {
val rpcCallSite = callSiteMap?.remove(it.key.toLong)
if (it.cause == RemovalCause.COLLECTED) {
log.warn(listOf(
"A hot observable returned from an RPC was never subscribed to.",
"This wastes server-side resources because it was queueing observations for retrieval.",
"It is being closed now, but please adjust your code to call .notUsed() on the observable",
"to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning",
"will appear less frequently in future versions of the platform and you can ignore it",
"if you want to.").joinToString(" "), rpcCallSite)
}
observablesToReap.locked { observables.add(it.key) }
}
return CacheBuilder.newBuilder().
weakValues().
removalListener(onObservableRemove).
concurrencyLevel(rpcConfiguration.cacheConcurrencyLevel).
build()
}
// We cannot pool consumers as we need to preserve the original muxed message order.
// TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a
// single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is
// integrated properly.
private lateinit var sessionAndConsumer: ArtemisConsumer
// Pool producers to reduce contention on the client side.
private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) {
// Note how we create new sessions *and* session factories per producer.
// We cannot simply pool producers on one session because sessions are single threaded.
// We cannot simply pool sessions on one session factory because flow control credits are tied to factories, so
// sessions tend to starve each other when used concurrently.
val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
session.start()
ArtemisProducer(sessionFactory, session, session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME))
}
/**
* Start the client. This creates the per-client queue, starts the consumer session and the reaper.
*/
fun start() {
reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate(
this::reapObservables,
rpcConfiguration.reapInterval.toMillis(),
rpcConfiguration.reapInterval.toMillis(),
TimeUnit.MILLISECONDS
)
sessionAndProducerPool.run {
it.session.createTemporaryQueue(clientAddress, clientAddress)
}
val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
val consumer = session.createConsumer(clientAddress)
consumer.setMessageHandler(this@RPCClientProxyHandler::artemisMessageHandler)
sessionAndConsumer = ArtemisConsumer(sessionFactory, session, consumer)
lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET)
session.start()
}
// This is the general function that transforms a client side RPC to internal Artemis messages.
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
lifeCycle.requireState { it == State.STARTED || it == State.SERVER_VERSION_NOT_SET }
checkProtocolVersion(method)
if (method == toStringMethod) {
return "Client RPC proxy for $rpcOpsClass"
}
if (sessionAndConsumer.session.isClosed) {
throw RPCException("RPC Proxy is closed")
}
val rpcId = RPCApi.RpcRequestId(random63BitValue())
callSiteMap?.set(rpcId.toLong, Throwable("<Call site of root RPC '${method.name}'>"))
try {
val request = RPCApi.ClientToServer.RpcRequest(clientAddress, rpcId, method.name, arguments?.toList() ?: emptyList())
val replyFuture = SettableFuture.create<Any>()
sessionAndProducerPool.run {
val message = it.session.createMessage(false)
request.writeToClientMessage(kryoPool, message)
log.debug {
val argumentsString = arguments?.joinToString() ?: ""
"-> RPC($rpcId) -> ${method.name}($argumentsString): ${method.returnType}"
}
require(rpcReplyMap.put(rpcId, replyFuture) == null) {
"Generated several RPC requests with same ID $rpcId"
}
it.producer.send(message)
it.session.commit()
}
return replyFuture.getOrThrow()
} finally {
callSiteMap?.remove(rpcId.toLong)
}
}
// The handler for Artemis messages.
private fun artemisMessageHandler(message: ClientMessage) {
val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message)
log.debug { "Got message from RPC server $serverToClient" }
when (serverToClient) {
is RPCApi.ServerToClient.RpcReply -> {
val replyFuture = rpcReplyMap.remove(serverToClient.id)
if (replyFuture == null) {
log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.")
} else {
val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong)
serverToClient.result.match(
onError = {
if (rpcCallSite != null) addRpcCallSiteToThrowable(it, rpcCallSite)
replyFuture.setException(it)
},
onValue = { replyFuture.set(it) }
)
}
}
is RPCApi.ServerToClient.Observation -> {
val observable = observableContext.observableMap.getIfPresent(serverToClient.id)
if (observable == null) {
log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " +
"This may be due to an observation arriving before the server was " +
"notified of observable shutdown")
} else {
// We schedule the onNext() on an executor sticky-pooled based on the Observable ID.
observationExecutorPool.run(serverToClient.id) { executor ->
executor.submit {
val content = serverToClient.content
if (content.isOnCompleted || content.isOnError) {
observableContext.observableMap.invalidate(serverToClient.id)
}
// Add call site information on error
if (content.isOnError) {
val rpcCallSite = callSiteMap?.get(serverToClient.id.toLong)
if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite)
}
observable.onNext(content)
}
}
}
}
}
message.acknowledge()
}
/**
* Closes the RPC proxy. Reaps all observables, shuts down the reaper, closes all sessions and executors.
*/
fun close() {
sessionAndConsumer.consumer.close()
sessionAndConsumer.session.close()
sessionAndConsumer.sessionFactory.close()
reaperScheduledFuture.cancel(false)
observableContext.observableMap.invalidateAll()
reapObservables()
reaperExecutor.shutdownNow()
sessionAndProducerPool.close().forEach {
it.producer.close()
it.session.close()
it.sessionFactory.close()
}
// Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may
// leak borrowed executors.
val observationExecutors = observationExecutorPool.close()
observationExecutors.forEach { it.shutdownNow() }
observationExecutors.forEach { it.awaitTermination(100, TimeUnit.MILLISECONDS) }
lifeCycle.transition(State.STARTED, State.FINISHED)
}
/**
* Check the [RPCSinceVersion] of the passed in [calledMethod] against the server's protocol version.
*/
private fun checkProtocolVersion(calledMethod: Method) {
val serverProtocolVersion = serverProtocolVersion
if (serverProtocolVersion == null) {
lifeCycle.requireState(State.SERVER_VERSION_NOT_SET)
} else {
lifeCycle.requireState(State.STARTED)
val sinceVersion = calledMethod.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
if (sinceVersion > serverProtocolVersion) {
throw UnsupportedOperationException("Method $calledMethod was added in RPC protocol version $sinceVersion but the server is running $serverProtocolVersion")
}
}
}
/**
* Set the server's protocol version. Note that before doing so the client is not considered fully started, although
* RPCs already may be called with it.
*/
internal fun setServerProtocolVersion(version: Int) {
if (serverProtocolVersion == null) {
serverProtocolVersion = version
} else {
throw IllegalStateException("setServerProtocolVersion called, but the protocol version was already set!")
}
lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED)
}
private fun reapObservables() {
observableContext.observableMap.cleanUp()
val observableIds = observablesToReap.locked {
if (observables.isNotEmpty()) {
val temporary = observables
observables = ArrayList()
temporary
} else {
null
}
}
if (observableIds != null) {
log.debug { "Reaping ${observableIds.size} observables" }
sessionAndProducerPool.run {
val message = it.session.createMessage(false)
RPCApi.ClientToServer.ObservablesClosed(observableIds).writeToClientMessage(message)
it.producer.send(message)
}
}
}
}
private typealias RpcObservableMap = Cache<RPCApi.ObservableId, UnicastSubject<Notification<Any>>>
private typealias RpcReplyMap = ConcurrentHashMap<RPCApi.RpcRequestId, SettableFuture<Any?>>
private typealias CallSiteMap = ConcurrentHashMap<Long, Throwable?>
/**
* Holds a context available during Kryo deserialisation of messages that are expected to contain Observables.
*
* @param observableMap holds the Observables that are ultimately exposed to the user.
* @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to.
*/
private data class ObservableContext(
val callSiteMap: CallSiteMap?,
val observableMap: RpcObservableMap,
val hardReferenceStore: MutableSet<Observable<*>>
)
/**
* A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext].
*/
private object RpcClientObservableSerializer : Serializer<Observable<Any>>() {
private object RpcObservableContextKey
fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool {
return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext)
}
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
@Suppress("UNCHECKED_CAST")
val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext
val observableId = RPCApi.ObservableId(input.readLong(true))
val observable = UnicastSubject.create<Notification<Any>>()
require(observableContext.observableMap.getIfPresent(observableId) == null) {
"Multiple Observables arrived with the same ID $observableId"
}
val rpcCallSite = getRpcCallSite(kryo, observableContext)
observableContext.observableMap.put(observableId, observable)
observableContext.callSiteMap?.put(observableId.toLong, rpcCallSite)
// We pin all Observables into a hard reference store (rooted in the RPC proxy) on subscription so that users
// don't need to store a reference to the Observables themselves.
return observable.pinInSubscriptions(observableContext.hardReferenceStore).doOnUnsubscribe {
// This causes Future completions to give warnings because the corresponding OnComplete sent from the server
// will arrive after the client unsubscribes from the observable and consequently invalidates the mapping.
// The unsubscribe is due to [ObservableToFuture]'s use of first().
observableContext.observableMap.invalidate(observableId)
}.dematerialize()
}
override fun write(kryo: Kryo, output: Output, observable: Observable<Any>) {
throw UnsupportedOperationException("Cannot serialise Observables on the client side")
}
private fun getRpcCallSite(kryo: Kryo, observableContext: ObservableContext): Throwable? {
val rpcRequestOrObservableId = kryo.context[RPCApi.RpcRequestOrObservableIdKey] as Long
return observableContext.callSiteMap?.get(rpcRequestOrObservableId)
}
}
private fun addRpcCallSiteToThrowable(throwable: Throwable, callSite: Throwable) {
var currentThrowable = throwable
while (true) {
val cause = currentThrowable.cause
if (cause == null) {
currentThrowable.initCause(callSite)
break
} else {
currentThrowable = cause
}
}
}
private fun <T> Observable<T>.pinInSubscriptions(hardReferenceStore: MutableSet<Observable<*>>): Observable<T> {
val refCount = AtomicInteger(0)
return this.doOnSubscribe {
if (refCount.getAndIncrement() == 0) {
require(hardReferenceStore.add(this)) { "Reference store already contained reference $this on add" }
}
}.doOnUnsubscribe {
if (refCount.decrementAndGet() == 0) {
require(hardReferenceStore.remove(this)) { "Reference store did not contain reference $this on remove" }
}
}
}

View File

@ -1,102 +0,0 @@
package net.corda.client.rpc
import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ALICE
import net.corda.core.utilities.LogHelper
import net.corda.node.services.RPCUserService
import net.corda.node.services.messaging.RPCDispatcher
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.User
import org.apache.activemq.artemis.api.core.Message
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.core.config.impl.ConfigurationImpl
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory
import org.apache.activemq.artemis.core.server.embedded.EmbeddedActiveMQ
import org.junit.After
import org.junit.Before
import java.util.*
import java.util.concurrent.locks.ReentrantLock
abstract class AbstractClientRPCTest {
lateinit var artemis: EmbeddedActiveMQ
lateinit var serverSession: ClientSession
lateinit var clientSession: ClientSession
lateinit var producer: ClientProducer
lateinit var serverThread: AffinityExecutor.ServiceAffinityExecutor
@Before
fun rpcSetup() {
// Set up an in-memory Artemis with an RPC requests queue.
artemis = EmbeddedActiveMQ()
artemis.setConfiguration(ConfigurationImpl().apply {
acceptorConfigurations = setOf(TransportConfiguration(InVMAcceptorFactory::class.java.name))
isSecurityEnabled = false
isPersistenceEnabled = false
})
artemis.start()
val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(TransportConfiguration(InVMConnectorFactory::class.java.name))
val sessionFactory = serverLocator.createSessionFactory()
serverSession = sessionFactory.createSession()
serverSession.start()
serverSession.createTemporaryQueue(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, ArtemisMessagingComponent.RPC_REQUESTS_QUEUE)
producer = serverSession.createProducer()
serverThread = AffinityExecutor.ServiceAffinityExecutor("unit-tests-rpc-dispatch-thread", 1)
serverSession.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 'BINDING_REMOVED'")
clientSession = sessionFactory.createSession()
clientSession.start()
LogHelper.setLevel("+net.corda.rpc")
}
@After
fun rpcShutdown() {
safeClose(producer)
clientSession.stop()
serverSession.stop()
artemis.stop()
serverThread.shutdownNow()
}
fun <T : RPCOps> rpcProxyFor(rpcUser: User, rpcImpl: T, type: Class<T>): T {
val userService = object : RPCUserService {
override fun getUser(username: String): User? = if (username == rpcUser.username) rpcUser else null
override val users: List<User> get() = listOf(rpcUser)
}
val dispatcher = object : RPCDispatcher(rpcImpl, userService, ALICE.name) {
override fun send(data: SerializedBytes<*>, toAddress: String) {
val msg = serverSession.createMessage(false).apply {
writeBodyBufferBytes(data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
producer.send(toAddress, msg)
}
override fun getUser(message: ClientMessage): User = rpcUser
}
val serverNotifConsumer = serverSession.createConsumer("rpc.qremovals")
val serverConsumer = serverSession.createConsumer(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE)
dispatcher.start(serverConsumer, serverNotifConsumer, serverThread)
return CordaRPCClientImpl(clientSession, ReentrantLock(), rpcUser.username).proxyFor(type)
}
fun safeClose(obj: Any) {
try {
(obj as AutoCloseable).close()
} catch (e: Exception) {
}
}
}

View File

@ -0,0 +1,56 @@
package net.corda.client.rpc
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.flatMap
import net.corda.core.map
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.User
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcTestUser
import net.corda.testing.startInVmRpcClient
import net.corda.testing.startRpcClient
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.junit.runners.Parameterized
open class AbstractRPCTest {
enum class RPCTestMode {
InVm,
Netty
}
companion object {
@JvmStatic @Parameterized.Parameters(name = "Mode = {0}")
fun defaultModes() = modes(RPCTestMode.InVm, RPCTestMode.Netty)
fun modes(vararg modes: RPCTestMode) = listOf(*modes).map { arrayOf(it) }
}
@Parameterized.Parameter
lateinit var mode: RPCTestMode
data class TestProxy<out I : RPCOps>(
val ops: I,
val createSession: () -> ClientSession
)
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.testProxy(
ops: I,
rpcUser: User = rpcTestUser,
clientConfiguration: RPCClientConfiguration = RPCClientConfiguration.default,
serverConfiguration: RPCServerConfiguration = RPCServerConfiguration.default
): TestProxy<I> {
return when (mode) {
RPCTestMode.InVm ->
startInVmRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap {
startInVmRpcClient<I>(rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) })
}
}.get()
RPCTestMode.Netty ->
startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server ->
startRpcClient<I>(server.hostAndPort, rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startArtemisSession(server.hostAndPort, rpcUser.username, rpcUser.password) })
}
}.get()
}
}
}

View File

@ -5,16 +5,16 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import net.corda.core.getOrThrow
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.RPCReturnsObservables
import net.corda.core.success
import net.corda.nodeapi.CURRENT_RPC_USER
import net.corda.node.services.messaging.getRpcContext
import net.corda.nodeapi.RPCSinceVersion
import net.corda.nodeapi.User
import org.apache.activemq.artemis.api.core.SimpleString
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcDriver
import net.corda.testing.rpcTestUser
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import rx.Observable
import rx.subjects.PublishSubject
import java.util.concurrent.CountDownLatch
@ -23,22 +23,11 @@ import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
@RunWith(Parameterized::class)
class ClientRPCInfrastructureTests : AbstractRPCTest() {
// TODO: Test that timeouts work
lateinit var proxy: TestOps
private val authenticatedUser = User("test", "password", permissions = setOf())
@Before
fun setup() {
proxy = rpcProxyFor(authenticatedUser, TestOpsImpl(), TestOps::class.java)
}
@After
fun shutdown() {
safeClose(proxy)
}
private fun RPCDriverExposedDSLInterface.testProxy() = testProxy<TestOps>(TestOpsImpl()).ops
interface TestOps : RPCOps {
@Throws(IllegalArgumentException::class)
@ -48,16 +37,12 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
fun someCalculation(str: String, num: Int): String
@RPCReturnsObservables
fun makeObservable(): Observable<Int>
@RPCReturnsObservables
fun makeComplicatedObservable(): Observable<Pair<String, Observable<String>>>
@RPCReturnsObservables
fun makeListenableFuture(): ListenableFuture<Int>
@RPCReturnsObservables
fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>>
@RPCSinceVersion(2)
@ -78,117 +63,130 @@ class ClientRPCInfrastructureTests : AbstractClientRPCTest() {
override fun makeListenableFuture(): ListenableFuture<Int> = Futures.immediateFuture(1)
override fun makeComplicatedObservable() = complicatedObservable
override fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>> = complicatedListenableFuturee
override fun addedLater(): Unit = throw UnsupportedOperationException("not implemented")
override fun captureUser(): String = CURRENT_RPC_USER.get().username
override fun addedLater(): Unit = throw IllegalStateException()
override fun captureUser(): String = getRpcContext().currentUser.username
}
@Test
fun `simple RPCs`() {
// Does nothing, doesn't throw.
proxy.void()
rpcDriver {
val proxy = testProxy()
// Does nothing, doesn't throw.
proxy.void()
assertEquals("Barf!", assertFailsWith<IllegalArgumentException> {
proxy.barf()
}.message)
assertEquals("Barf!", assertFailsWith<IllegalArgumentException> {
proxy.barf()
}.message)
assertEquals("hi 5", proxy.someCalculation("hi", 5))
assertEquals("hi 5", proxy.someCalculation("hi", 5))
}
}
@Test
fun `simple observable`() {
// This tests that the observations are transmitted correctly, also completion is transmitted.
val observations = proxy.makeObservable().toBlocking().toIterable().toList()
assertEquals(listOf(1, 2, 3, 4), observations)
rpcDriver {
val proxy = testProxy()
// This tests that the observations are transmitted correctly, also completion is transmitted.
val observations = proxy.makeObservable().toBlocking().toIterable().toList()
assertEquals(listOf(1, 2, 3, 4), observations)
}
}
@Test
fun `complex observables`() {
// This checks that we can return an object graph with complex usage of observables, like an observable
// that emits objects that contain more observables.
val serverQuotes = PublishSubject.create<Pair<String, Observable<String>>>()
val unsubscribeLatch = CountDownLatch(1)
complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() }
rpcDriver {
val proxy = testProxy()
// This checks that we can return an object graph with complex usage of observables, like an observable
// that emits objects that contain more observables.
val serverQuotes = PublishSubject.create<Pair<String, Observable<String>>>()
val unsubscribeLatch = CountDownLatch(1)
complicatedObservable = serverQuotes.asObservable().doOnUnsubscribe { unsubscribeLatch.countDown() }
val twainQuotes = "Mark Twain" to Observable.just(
"I have never let my schooling interfere with my education.",
"Clothes make the man. Naked people have little or no influence on society."
)
val wildeQuotes = "Oscar Wilde" to Observable.just(
"I can resist everything except temptation.",
"Always forgive your enemies - nothing annoys them so much."
)
val twainQuotes = "Mark Twain" to Observable.just(
"I have never let my schooling interfere with my education.",
"Clothes make the man. Naked people have little or no influence on society."
)
val wildeQuotes = "Oscar Wilde" to Observable.just(
"I can resist everything except temptation.",
"Always forgive your enemies - nothing annoys them so much."
)
val clientQuotes = LinkedBlockingQueue<String>()
val clientObs = proxy.makeComplicatedObservable()
val clientQuotes = LinkedBlockingQueue<String>()
val clientObs = proxy.makeComplicatedObservable()
val subscription = clientObs.subscribe {
val name = it.first
it.second.subscribe {
clientQuotes += "Quote by $name: $it"
val subscription = clientObs.subscribe {
val name = it.first
it.second.subscribe {
clientQuotes += "Quote by $name: $it"
}
}
assertThat(clientQuotes).isEmpty()
serverQuotes.onNext(twainQuotes)
assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take())
assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take())
serverQuotes.onNext(wildeQuotes)
assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take())
assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take())
assertTrue(serverQuotes.hasObservers())
subscription.unsubscribe()
unsubscribeLatch.await()
}
val rpcQueuesQuery = SimpleString("clients.${authenticatedUser.username}.rpc.*")
assertEquals(2, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
assertThat(clientQuotes).isEmpty()
serverQuotes.onNext(twainQuotes)
assertEquals("Quote by Mark Twain: I have never let my schooling interfere with my education.", clientQuotes.take())
assertEquals("Quote by Mark Twain: Clothes make the man. Naked people have little or no influence on society.", clientQuotes.take())
serverQuotes.onNext(wildeQuotes)
assertEquals("Quote by Oscar Wilde: I can resist everything except temptation.", clientQuotes.take())
assertEquals("Quote by Oscar Wilde: Always forgive your enemies - nothing annoys them so much.", clientQuotes.take())
assertTrue(serverQuotes.hasObservers())
subscription.unsubscribe()
unsubscribeLatch.await()
assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
}
@Test
fun `simple ListenableFuture`() {
val value = proxy.makeListenableFuture().getOrThrow()
assertThat(value).isEqualTo(1)
rpcDriver {
val proxy = testProxy()
val value = proxy.makeListenableFuture().getOrThrow()
assertThat(value).isEqualTo(1)
}
}
@Test
fun `complex ListenableFuture`() {
val serverQuote = SettableFuture.create<Pair<String, ListenableFuture<String>>>()
complicatedListenableFuturee = serverQuote
rpcDriver {
val proxy = testProxy()
val serverQuote = SettableFuture.create<Pair<String, ListenableFuture<String>>>()
complicatedListenableFuturee = serverQuote
val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.")
val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.")
val clientQuotes = LinkedBlockingQueue<String>()
val clientFuture = proxy.makeComplicatedListenableFuture()
val clientQuotes = LinkedBlockingQueue<String>()
val clientFuture = proxy.makeComplicatedListenableFuture()
clientFuture.success {
val name = it.first
it.second.success {
clientQuotes += "Quote by $name: $it"
clientFuture.success {
val name = it.first
it.second.success {
clientQuotes += "Quote by $name: $it"
}
}
assertThat(clientQuotes).isEmpty()
serverQuote.set(twainQuote)
assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.")
// TODO This final assert sometimes fails because the relevant queue hasn't been removed yet
}
val rpcQueuesQuery = SimpleString("clients.${authenticatedUser.username}.rpc.*")
assertEquals(2, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
assertThat(clientQuotes).isEmpty()
serverQuote.set(twainQuote)
assertThat(clientQuotes.take()).isEqualTo("Quote by Mark Twain: I have never let my schooling interfere with my education.")
// TODO This final assert sometimes fails because the relevant queue hasn't been removed yet
// assertEquals(1, clientSession.addressQuery(rpcQueuesQuery).queueNames.size)
}
@Test
fun versioning() {
assertFailsWith<UnsupportedOperationException> { proxy.addedLater() }
rpcDriver {
val proxy = testProxy()
assertFailsWith<UnsupportedOperationException> { proxy.addedLater() }
}
}
@Test
fun `authenticated user is available to RPC`() {
assertThat(proxy.captureUser()).isEqualTo(authenticatedUser.username)
rpcDriver {
val proxy = testProxy()
assertThat(proxy.captureUser()).isEqualTo(rpcTestUser.username)
}
}
}

View File

@ -0,0 +1,182 @@
package net.corda.client.rpc
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.future
import net.corda.core.messaging.RPCOps
import net.corda.core.millis
import net.corda.core.random63BitValue
import net.corda.core.serialization.CordaSerializable
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcDriver
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import rx.Observable
import rx.subjects.UnicastSubject
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
@RunWith(Parameterized::class)
class RPCConcurrencyTests : AbstractRPCTest() {
/**
* Holds a "rose"-tree of [Observable]s which allows us to test arbitrary [Observable] nesting in RPC replies.
*/
@CordaSerializable
data class ObservableRose<out A>(val value: A, val branches: Observable<out ObservableRose<A>>)
private interface TestOps : RPCOps {
fun newLatch(numberOfDowns: Int): Long
fun waitLatch(id: Long)
fun downLatch(id: Long)
fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int>
}
class TestOpsImpl : TestOps {
private val latches = ConcurrentHashMap<Long, CountDownLatch>()
override val protocolVersion = 0
override fun newLatch(numberOfDowns: Int): Long {
val id = random63BitValue()
val latch = CountDownLatch(numberOfDowns)
latches.put(id, latch)
return id
}
override fun waitLatch(id: Long) {
latches[id]!!.await()
}
override fun downLatch(id: Long) {
latches[id]!!.countDown()
}
override fun getImmediateObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> {
val branches = if (depth == 0) {
Observable.empty<ObservableRose<Int>>()
} else {
Observable.just(getImmediateObservableTree(depth - 1, branchingFactor)).repeat(branchingFactor.toLong())
}
return ObservableRose(depth, branches)
}
override fun getParallelObservableTree(depth: Int, branchingFactor: Int): ObservableRose<Int> {
val branches = if (depth == 0) {
Observable.empty<ObservableRose<Int>>()
} else {
val publish = UnicastSubject.create<ObservableRose<Int>>()
future {
(1..branchingFactor).toList().parallelStream().forEach {
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
}
publish.onCompleted()
}
publish
}
return ObservableRose(depth, branches)
}
}
private lateinit var testOpsImpl: TestOpsImpl
private fun RPCDriverExposedDSLInterface.testProxy(): TestProxy<TestOps> {
testOpsImpl = TestOpsImpl()
return testProxy<TestOps>(
testOpsImpl,
clientConfiguration = RPCClientConfiguration.default.copy(
reapInterval = 100.millis,
cacheConcurrencyLevel = 16
),
serverConfiguration = RPCServerConfiguration.default.copy(
rpcThreadPoolSize = 4
)
)
}
@Test
fun `call multiple RPCs in parallel`() {
rpcDriver {
val proxy = testProxy()
val numberOfBlockedCalls = 2
val numberOfDownsRequired = 100
val id = proxy.ops.newLatch(numberOfDownsRequired)
val done = CountDownLatch(numberOfBlockedCalls)
// Start a couple of blocking RPC calls
(1..numberOfBlockedCalls).forEach {
future {
proxy.ops.waitLatch(id)
done.countDown()
}
}
// Down the latch that the others are waiting for concurrently
(1..numberOfDownsRequired).toList().parallelStream().forEach {
proxy.ops.downLatch(id)
}
done.await()
}
}
private fun intPower(base: Int, power: Int): Int {
return when (power) {
0 -> 1
1 -> base
else -> {
val a = intPower(base, power / 2)
if (power and 1 == 0) {
a * a
} else {
a * a * base
}
}
}
}
@Test
fun `nested immediate observables sequence correctly`() {
rpcDriver {
// We construct a rose tree of immediate Observables and check that parent observations arrive before children.
val proxy = testProxy()
val treeDepth = 6
val treeBranchingFactor = 3
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown()
this.branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" }
}
depthsSeen.add(tree.value)
tree.subscribeToAll()
}
}
proxy.ops.getImmediateObservableTree(treeDepth, treeBranchingFactor).subscribeToAll()
remainingLatch.await()
}
}
@Test
fun `parallel nested observables`() {
rpcDriver {
val proxy = testProxy()
val treeDepth = 2
val treeBranchingFactor = 10
val remainingLatch = CountDownLatch((intPower(treeBranchingFactor, treeDepth + 1) - 1) / (treeBranchingFactor - 1))
val depthsSeen = Collections.synchronizedSet(HashSet<Int>())
fun ObservableRose<Int>.subscribeToAll() {
remainingLatch.countDown()
branches.subscribe { tree ->
(tree.value + 1..treeDepth - 1).forEach {
require(it in depthsSeen) { "Got ${tree.value} before $it" }
}
depthsSeen.add(tree.value)
tree.subscribeToAll()
}
}
proxy.ops.getParallelObservableTree(treeDepth, treeBranchingFactor).subscribeToAll()
remainingLatch.await()
}
}
}

View File

@ -0,0 +1,315 @@
package net.corda.client.rpc
import com.codahale.metrics.ConsoleReporter
import com.codahale.metrics.Gauge
import com.codahale.metrics.JmxReporter
import com.codahale.metrics.MetricRegistry
import com.google.common.base.Stopwatch
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.RPCOps
import net.corda.core.minutes
import net.corda.core.seconds
import net.corda.core.utilities.Rate
import net.corda.core.utilities.div
import net.corda.node.driver.ShutdownManager
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.measure
import net.corda.testing.rpcDriver
import org.junit.Ignore
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.time.Duration
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.ReentrantLock
import javax.management.ObjectName
import kotlin.concurrent.thread
import kotlin.concurrent.withLock
@Ignore("Only use this locally for profiling")
@RunWith(Parameterized::class)
class RPCPerformanceTests : AbstractRPCTest() {
companion object {
@JvmStatic @Parameterized.Parameters(name = "Mode = {0}")
fun modes() = modes(RPCTestMode.Netty)
}
private interface TestOps : RPCOps {
fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray
}
class TestOpsImpl : TestOps {
override val protocolVersion = 0
override fun simpleReply(input: ByteArray, sizeOfReply: Int): ByteArray {
return ByteArray(sizeOfReply)
}
}
private fun RPCDriverExposedDSLInterface.testProxy(
clientConfiguration: RPCClientConfiguration,
serverConfiguration: RPCServerConfiguration
): TestProxy<TestOps> {
return testProxy<TestOps>(
TestOpsImpl(),
clientConfiguration = clientConfiguration,
serverConfiguration = serverConfiguration
)
}
private fun warmup() {
rpcDriver {
val proxy = testProxy(
RPCClientConfiguration.default,
RPCServerConfiguration.default
)
val executor = Executors.newFixedThreadPool(4)
val N = 10000
val latch = CountDownLatch(N)
for (i in 1 .. N) {
executor.submit {
proxy.ops.simpleReply(ByteArray(1024), 1024)
latch.countDown()
}
}
latch.await()
}
}
data class SimpleRPCResult(
val requestPerSecond: Double,
val averageIndividualMs: Double,
val Mbps: Double
)
@Test
fun `measure Megabytes per second for simple RPCs`() {
warmup()
val inputOutputSizes = listOf(1024, 4096, 100 * 1024)
val overallTraffic = 512 * 1024 * 1024L
measure(inputOutputSizes, (1..5)) { inputOutputSize, N ->
rpcDriver {
val proxy = testProxy(
RPCClientConfiguration.default.copy(
cacheConcurrencyLevel = 16,
observationExecutorPoolSize = 2,
producerPoolBound = 2
),
RPCServerConfiguration.default.copy(
rpcThreadPoolSize = 8,
consumerPoolSize = 2,
producerPoolBound = 8
)
)
val numberOfRequests = overallTraffic / (2 * inputOutputSize)
val timings = Collections.synchronizedList(ArrayList<Long>())
val executor = Executors.newFixedThreadPool(8)
val totalElapsed = Stopwatch.createStarted().apply {
startInjectorWithBoundedQueue(
executor = executor,
numberOfInjections = numberOfRequests.toInt(),
queueBound = 100
) {
val elapsed = Stopwatch.createStarted().apply {
proxy.ops.simpleReply(ByteArray(inputOutputSize), inputOutputSize)
}.stop().elapsed(TimeUnit.MICROSECONDS)
timings.add(elapsed)
}
}.stop().elapsed(TimeUnit.MICROSECONDS)
executor.shutdownNow()
SimpleRPCResult(
requestPerSecond = 1000000.0 * numberOfRequests.toDouble() / totalElapsed.toDouble(),
averageIndividualMs = timings.average() / 1000.0,
Mbps = (overallTraffic.toDouble() / totalElapsed.toDouble()) * (1000000.0 / (1024.0 * 1024.0))
)
}
}.forEach(::println)
}
/**
* Runs 20k RPCs per second for two minutes and publishes relevant stats to JMX.
*/
@Test
fun `consumption rate`() {
rpcDriver {
val metricRegistry = startReporter()
val proxy = testProxy(
RPCClientConfiguration.default.copy(
reapInterval = 1.seconds,
cacheConcurrencyLevel = 16,
producerPoolBound = 8
),
RPCServerConfiguration.default.copy(
rpcThreadPoolSize = 8,
consumerPoolSize = 1,
producerPoolBound = 8
)
)
measurePerformancePublishMetrics(
metricRegistry = metricRegistry,
parallelism = 8,
overallDuration = 5.minutes,
injectionRate = 20000L / TimeUnit.SECONDS,
queueSizeMetricName = "$mode.QueueSize",
workDurationMetricName = "$mode.WorkDuration",
shutdownManager = this.shutdownManager,
work = {
proxy.ops.simpleReply(ByteArray(4096), 4096)
}
)
}
}
data class BigMessagesResult(
val Mbps: Double
)
@Test
fun `big messages`() {
warmup()
measure(listOf(1)) { clientParallelism -> // TODO this hangs with more parallelism
rpcDriver {
val proxy = testProxy(
RPCClientConfiguration.default,
RPCServerConfiguration.default.copy(
consumerPoolSize = 1
)
)
val executor = Executors.newFixedThreadPool(clientParallelism)
val numberOfMessages = 1000
val bigSize = 10_000_000
val elapsed = Stopwatch.createStarted().apply {
startInjectorWithBoundedQueue(
executor = executor,
numberOfInjections = numberOfMessages,
queueBound = 4
) {
proxy.ops.simpleReply(ByteArray(bigSize), 0)
}
}.stop().elapsed(TimeUnit.MICROSECONDS)
executor.shutdownNow()
BigMessagesResult(
Mbps = bigSize.toDouble() * numberOfMessages.toDouble() / elapsed * (1000000.0 / (1024.0 * 1024.0))
)
}
}.forEach(::println)
}
}
fun measurePerformancePublishMetrics(
metricRegistry: MetricRegistry,
parallelism: Int,
overallDuration: Duration,
injectionRate: Rate,
queueSizeMetricName: String,
workDurationMetricName: String,
shutdownManager: ShutdownManager,
work: () -> Unit
) {
val workSemaphore = Semaphore(0)
metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() })
val workDurationTimer = metricRegistry.timer(workDurationMetricName)
val executor = Executors.newSingleThreadScheduledExecutor()
val workExecutor = Executors.newFixedThreadPool(parallelism)
val timings = Collections.synchronizedList(ArrayList<Long>())
for (i in 1 .. parallelism) {
workExecutor.submit {
try {
while (true) {
workSemaphore.acquire()
workDurationTimer.time {
timings.add(
Stopwatch.createStarted().apply {
work()
}.stop().elapsed(TimeUnit.MICROSECONDS)
)
}
}
} catch (throwable: Throwable) {
throwable.printStackTrace()
}
}
}
val injector = executor.scheduleAtFixedRate(
{
workSemaphore.release((injectionRate * TimeUnit.SECONDS).toInt())
},
0,
1,
TimeUnit.SECONDS
)
shutdownManager.registerShutdown {
injector.cancel(true)
workExecutor.shutdownNow()
executor.shutdownNow()
workExecutor.awaitTermination(1, TimeUnit.SECONDS)
executor.awaitTermination(1, TimeUnit.SECONDS)
}
Thread.sleep(overallDuration.toMillis())
}
fun startInjectorWithBoundedQueue(
executor: ExecutorService,
numberOfInjections: Int,
queueBound: Int,
work: () -> Unit
) {
val remainingLatch = CountDownLatch(numberOfInjections)
val queuedCount = AtomicInteger(0)
val lock = ReentrantLock()
val canQueueAgain = lock.newCondition()
val injectorShutdown = AtomicBoolean(false)
val injector = thread(name = "injector") {
while (true) {
if (injectorShutdown.get()) break
executor.submit {
work()
if (queuedCount.decrementAndGet() < queueBound / 2) {
lock.withLock {
canQueueAgain.signal()
}
}
remainingLatch.countDown()
}
if (queuedCount.incrementAndGet() > queueBound) {
lock.withLock {
canQueueAgain.await()
}
}
}
}
remainingLatch.await()
injectorShutdown.set(true)
injector.join()
}
fun RPCDriverExposedDSLInterface.startReporter(): MetricRegistry {
val metricRegistry = MetricRegistry()
val jmxReporter = thread {
JmxReporter.
forRegistry(metricRegistry).
inDomain("net.corda").
createsObjectNamesWith { _, domain, name ->
// Make the JMX hierarchy a bit better organised.
val category = name.substringBefore('.')
val subName = name.substringAfter('.', "")
if (subName == "")
ObjectName("$domain:name=$category")
else
ObjectName("$domain:type=$category,name=$subName")
}.
build().
start()
}
val consoleReporter = thread {
ConsoleReporter.forRegistry(metricRegistry).build().start(1, TimeUnit.SECONDS)
}
shutdownManager.registerShutdown {
jmxReporter.interrupt()
consoleReporter.interrupt()
jmxReporter.join()
consoleReporter.join()
}
return metricRegistry
}

View File

@ -1,85 +0,0 @@
package net.corda.client.rpc
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.requirePermission
import net.corda.nodeapi.PermissionException
import net.corda.nodeapi.User
import org.junit.After
import org.junit.Test
import kotlin.test.assertFailsWith
class RPCPermissionsTest : AbstractClientRPCTest() {
companion object {
const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow"
const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow"
const val ALL_ALLOWED = "ALL"
}
lateinit var proxy: TestOps
@After
fun shutdown() {
safeClose(proxy)
}
/*
* RPC operation.
*/
interface TestOps : RPCOps {
fun validatePermission(str: String)
}
class TestOpsImpl : TestOps {
override val protocolVersion = 1
override fun validatePermission(str: String) = requirePermission(str)
}
/**
* Create an RPC proxy for the given user.
*/
private fun proxyFor(rpcUser: User): TestOps = rpcProxyFor(rpcUser, TestOpsImpl(), TestOps::class.java)
private fun userOf(name: String, permissions: Set<String>) = User(name, "password", permissions)
@Test
fun `empty user cannot use any flows`() {
val emptyUser = userOf("empty", emptySet())
proxy = proxyFor(emptyUser)
assertFailsWith(PermissionException::class,
"User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.",
{ proxy.validatePermission(DUMMY_FLOW) })
}
@Test
fun `admin user can use any flow`() {
val adminUser = userOf("admin", setOf(ALL_ALLOWED))
proxy = proxyFor(adminUser)
proxy.validatePermission(DUMMY_FLOW)
}
@Test
fun `joe user is allowed to use DummyFlow`() {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
proxy = proxyFor(joeUser)
proxy.validatePermission(DUMMY_FLOW)
}
@Test
fun `joe user is not allowed to use OtherFlow`() {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
proxy = proxyFor(joeUser)
assertFailsWith(PermissionException::class,
"User ${joeUser.username} should not be allowed to use $OTHER_FLOW",
{ proxy.validatePermission(OTHER_FLOW) })
}
@Test
fun `check ALL is implemented the correct way round`() {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
proxy = proxyFor(joeUser)
assertFailsWith(PermissionException::class,
"Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}",
{ proxy.validatePermission(ALL_ALLOWED) })
}
}

View File

@ -0,0 +1,93 @@
package net.corda.client.rpc
import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.requirePermission
import net.corda.node.services.messaging.getRpcContext
import net.corda.nodeapi.PermissionException
import net.corda.nodeapi.User
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.rpcDriver
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.test.assertFailsWith
@RunWith(Parameterized::class)
class RPCPermissionsTests : AbstractRPCTest() {
companion object {
const val DUMMY_FLOW = "StartFlow.net.corda.flows.DummyFlow"
const val OTHER_FLOW = "StartFlow.net.corda.flows.OtherFlow"
const val ALL_ALLOWED = "ALL"
}
/*
* RPC operation.
*/
interface TestOps : RPCOps {
fun validatePermission(str: String)
}
class TestOpsImpl : TestOps {
override val protocolVersion = 1
override fun validatePermission(str: String) = getRpcContext().requirePermission(str)
}
/**
* Create an RPC proxy for the given user.
*/
private fun RPCDriverExposedDSLInterface.testProxyFor(rpcUser: User) = testProxy<TestOps>(TestOpsImpl(), rpcUser).ops
private fun userOf(name: String, permissions: Set<String>) = User(name, "password", permissions)
@Test
fun `empty user cannot use any flows`() {
rpcDriver {
val emptyUser = userOf("empty", emptySet())
val proxy = testProxyFor(emptyUser)
assertFailsWith(PermissionException::class,
"User ${emptyUser.username} should not be allowed to use $DUMMY_FLOW.",
{ proxy.validatePermission(DUMMY_FLOW) })
}
}
@Test
fun `admin user can use any flow`() {
rpcDriver {
val adminUser = userOf("admin", setOf(ALL_ALLOWED))
val proxy = testProxyFor(adminUser)
proxy.validatePermission(DUMMY_FLOW)
}
}
@Test
fun `joe user is allowed to use DummyFlow`() {
rpcDriver {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
val proxy = testProxyFor(joeUser)
proxy.validatePermission(DUMMY_FLOW)
}
}
@Test
fun `joe user is not allowed to use OtherFlow`() {
rpcDriver {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
val proxy = testProxyFor(joeUser)
assertFailsWith(PermissionException::class,
"User ${joeUser.username} should not be allowed to use $OTHER_FLOW",
{ proxy.validatePermission(OTHER_FLOW) })
}
}
@Test
fun `check ALL is implemented the correct way round` () {
rpcDriver {
val joeUser = userOf("joe", setOf(DUMMY_FLOW))
val proxy = testProxyFor(joeUser)
assertFailsWith(PermissionException::class,
"Permission $ALL_ALLOWED should not do anything for User ${joeUser.username}",
{ proxy.validatePermission(ALL_ALLOWED) })
}
}
}

View File

@ -0,0 +1,25 @@
package net.corda.client.rpc
import java.io.InputStream
class RepeatingBytesInputStream(val bytesToRepeat: ByteArray, val numberOfBytes: Int) : InputStream() {
private var bytesLeft = numberOfBytes
override fun available() = bytesLeft
override fun read(): Int {
if (bytesLeft == 0) {
return -1
} else {
bytesLeft--
return bytesToRepeat[(numberOfBytes - bytesLeft) % bytesToRepeat.size].toInt()
}
}
override fun read(byteArray: ByteArray, offset: Int, length: Int): Int {
val until = Math.min(Math.min(offset + length, byteArray.size), offset + bytesLeft)
for (i in offset .. until - 1) {
byteArray[i] = bytesToRepeat[(numberOfBytes - bytesLeft + i - offset) % bytesToRepeat.size]
}
val bytesRead = until - offset
bytesLeft -= bytesRead
return if (bytesRead == 0 && bytesLeft == 0) -1 else bytesRead
}
}

View File

@ -404,6 +404,8 @@ data class ErrorOr<out A> private constructor(val value: A?, val error: Throwabl
ErrorOr.of(error)
}
}
fun mapError(function: (Throwable) -> Throwable) = ErrorOr(value, error?.let(function))
}
/**

View File

@ -3,6 +3,7 @@ package net.corda.core.serialization
import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoCallback
import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.util.MapReferenceResolver
import com.google.common.annotations.VisibleForTesting
@ -10,6 +11,7 @@ import net.corda.core.contracts.*
import net.corda.core.crypto.*
import net.corda.core.node.AttachmentsClassLoader
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.LazyPool
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec
@ -19,7 +21,9 @@ import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.x500.X500Name
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.*
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.InputStream
import java.lang.reflect.InvocationTargetException
import java.nio.file.Files
import java.nio.file.Path
@ -143,13 +147,26 @@ fun <T : Any> T.serialize(kryo: KryoPool = p2PKryo(), internalOnly: Boolean = fa
return kryo.run { k -> serialize(k, internalOnly) }
}
private val serializeBufferPool = LazyPool(
newInstance = { ByteArray(64 * 1024) }
)
private val serializeOutputStreamPool = LazyPool(
clear = ByteArrayOutputStream::reset,
shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large
newInstance = { ByteArrayOutputStream(64 * 1024) }
)
fun <T : Any> T.serialize(kryo: Kryo, internalOnly: Boolean = false): SerializedBytes<T> {
val stream = ByteArrayOutputStream()
Output(stream).use {
it.writeBytes(KryoHeaderV0_1.bytes)
kryo.writeClassAndObject(it, this)
return serializeOutputStreamPool.run { stream ->
serializeBufferPool.run { buffer ->
Output(buffer).use {
it.outputStream = stream
it.writeBytes(KryoHeaderV0_1.bytes)
kryo.writeClassAndObject(it, this)
}
SerializedBytes(stream.toByteArray(), internalOnly)
}
}
return SerializedBytes(stream.toByteArray(), internalOnly)
}
/**
@ -592,4 +609,26 @@ object X500NameSerializer : Serializer<X500Name>() {
override fun write(kryo: Kryo, output: Output, obj: X500Name) {
output.writeBytes(obj.encoded)
}
}
}
class KryoPoolWithContext(val baseKryoPool: KryoPool, val contextKey: Any, val context: Any) : KryoPool {
override fun <T : Any?> run(callback: KryoCallback<T>): T {
val kryo = borrow()
try {
return callback.execute(kryo)
} finally {
release(kryo)
}
}
override fun borrow(): Kryo {
val kryo = baseKryoPool.borrow()
require(kryo.context.put(contextKey, context) == null) { "KryoPool already has context" }
return kryo
}
override fun release(kryo: Kryo) {
requireNotNull(kryo.context.remove(contextKey)) { "Kryo instance lost context while borrowed" }
baseKryoPool.release(kryo)
}
}

View File

@ -0,0 +1,74 @@
package net.corda.core.utilities
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.Semaphore
/**
* A lazy pool of resources [A].
*
* @param clear If specified this function will be run on each borrowed instance before handing it over.
* @param shouldReturnToPool If specified this function will be run on each release to determine whether the instance
* should be returned to the pool for reuse. This may be useful for pooled resources that dynamically grow during
* usage, and we may not want to retain them forever.
* @param bound If specified the pool will be bounded. Once all instances are borrowed subsequent borrows will block until an
* instance is released.
* @param newInstance The function to call to lazily newInstance a pooled resource.
*/
class LazyPool<A>(
private val clear: ((A) -> Unit)? = null,
private val shouldReturnToPool: ((A) -> Boolean)? = null,
private val bound: Int? = null,
private val newInstance: () -> A
) {
private val poolQueue = ConcurrentLinkedQueue<A>()
private val poolSemaphore = Semaphore(bound ?: Int.MAX_VALUE)
private enum class State {
STARTED,
FINISHED
}
private val lifeCycle = LifeCycle(State.STARTED)
private fun clearIfNeeded(instance: A): A {
clear?.invoke(instance)
return instance
}
fun borrow(): A {
lifeCycle.requireState(State.STARTED)
poolSemaphore.acquire()
val pooled = poolQueue.poll()
if (pooled == null) {
return newInstance()
} else {
return clearIfNeeded(pooled)
}
}
fun release(instance: A) {
lifeCycle.requireState(State.STARTED)
if (shouldReturnToPool == null || shouldReturnToPool.invoke(instance)) {
poolQueue.add(instance)
}
poolSemaphore.release()
}
/**
* Closes the pool. Note that all borrowed instances must have been released before calling this function, otherwise
* the returned iterable will be inaccurate.
*/
fun close(): Iterable<A> {
lifeCycle.transition(State.STARTED, State.FINISHED)
return poolQueue
}
inline fun <R> run(withInstance: (A) -> R): R {
val instance = borrow()
try {
return withInstance(instance)
} finally {
release(instance)
}
}
}

View File

@ -0,0 +1,70 @@
package net.corda.core.utilities
import java.util.*
import java.util.concurrent.LinkedBlockingQueue
/**
* A [LazyStickyPool] is a lazy pool of resources where a [borrow] may "stick" the borrowed instance to an object.
* Any subsequent borrows using the same object will return the same pooled instance.
*
* @param size The size of the pool.
* @param newInstance The function to call to create a pooled resource.
*/
// TODO This could be implemented more efficiently. Currently the "non-sticky" use case is not optimised, it just chooses a random instance to wait on.
class LazyStickyPool<A : Any>(
size: Int,
private val newInstance: () -> A
) {
private class InstanceBox<A> {
var instance: LinkedBlockingQueue<A>? = null
}
private val random = Random()
private val boxes = Array(size) { InstanceBox<A>() }
private fun toIndex(stickTo: Any): Int {
return Math.abs(stickTo.hashCode()) % boxes.size
}
fun borrow(stickTo: Any): A {
val box = boxes[toIndex(stickTo)]
val instance = synchronized(box) {
val instance = box.instance
if (instance == null) {
val newInstance = LinkedBlockingQueue(listOf(newInstance()))
box.instance = newInstance
newInstance
} else {
instance
}
}
return instance.take()
}
fun borrow(): Pair<Any, A> {
val randomInt = random.nextInt()
val instance = borrow(randomInt)
return Pair(randomInt, instance)
}
fun release(stickTo: Any, instance: A) {
val box = boxes[toIndex(stickTo)]
box.instance!!.add(instance)
}
inline fun <R> run(stickToOrNull: Any? = null, withInstance: (A) -> R): R {
val (stickTo, instance) = if (stickToOrNull == null) {
borrow()
} else {
Pair(stickToOrNull, borrow(stickToOrNull))
}
try {
return withInstance(instance)
} finally {
release(stickTo, instance)
}
}
fun close(): Iterable<A> {
return boxes.map { it.instance?.poll() }.filterNotNull()
}
}

View File

@ -0,0 +1,38 @@
package net.corda.core.utilities
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.withLock
/**
* This class provides simple tracking of the lifecycle of a service-type object.
* [S] is an enum enumerating the possible states the service can be in.
*
* @param initial The initial state.
*/
class LifeCycle<S : Enum<S>>(initial: S) {
private val lock = ReentrantReadWriteLock()
private var state = initial
/** Assert that the lifecycle in the [requiredState] */
fun requireState(requiredState: S) {
requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState }
}
/** Assert something about the current state atomically. */
fun requireState(
errorMessage: (S) -> String = { "Predicate failed on state $it" },
predicate: (S) -> Boolean
) {
lock.readLock().withLock {
require(predicate(state)) { errorMessage(state) }
}
}
/** Transition the state from [from] to [to] */
fun transition(from: S, to: S) {
lock.writeLock().withLock {
require(state == from) { "Required state to be $from to transition to $to, was $state" }
state = to
}
}
}

View File

@ -0,0 +1,29 @@
package net.corda.core.utilities
import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.concurrent.TimeUnit
/**
* [Rate] holds a quantity denoting the frequency of some event e.g. 100 times per second or 2 times per day.
*/
data class Rate(
val numberOfEvents: Long,
val perTimeUnit: TimeUnit
) {
/**
* Returns the interval between two subsequent events.
*/
fun toInterval(): Duration {
return Duration.of(TimeUnit.NANOSECONDS.convert(1, perTimeUnit) / numberOfEvents, ChronoUnit.NANOS)
}
/**
* Converts the number of events to the given unit.
*/
operator fun times(inUnit: TimeUnit): Long {
return inUnit.convert(numberOfEvents, perTimeUnit)
}
}
operator fun Long.div(timeUnit: TimeUnit) = Rate(this, timeUnit)

View File

@ -5,6 +5,7 @@ import net.corda.core.contracts.*
import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.getOrThrow
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.startFlow
import net.corda.core.node.services.unconsumedStates
import net.corda.core.serialization.OpaqueBytes
@ -15,9 +16,12 @@ import net.corda.flows.FinalityFlow
import net.corda.node.internal.CordaRPCOpsImpl
import net.corda.node.services.startFlowPermission
import net.corda.node.utilities.transaction
import net.corda.nodeapi.CURRENT_RPC_USER
import net.corda.nodeapi.User
import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.node.MockNetwork
import net.corda.testing.rpcDriver
import net.corda.testing.rpcTestUser
import net.corda.testing.startRpcClient
import org.junit.After
import org.junit.Before
import org.junit.Test
@ -99,60 +103,71 @@ class ContractUpgradeFlowTest {
check(b)
}
private fun RPCDriverExposedDSLInterface.startProxy(node: MockNetwork.MockNode, user: User): CordaRPCOps {
return startRpcClient<CordaRPCOps>(
rpcAddress = startRpcServer(
rpcUser = user,
ops = CordaRPCOpsImpl(node.services, node.smm, node.database)
).get().hostAndPort,
username = user.username,
password = user.password
).get()
}
@Test
fun `2 parties contract upgrade using RPC`() {
// Create dummy contract.
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, a.info.legalIdentity.ref(1), b.info.legalIdentity.ref(1))
val stx = twoPartyDummyContract.signWith(a.services.legalIdentityKey)
.signWith(b.services.legalIdentityKey)
.toSignedTransaction()
rpcDriver {
// Create dummy contract.
val twoPartyDummyContract = DummyContract.generateInitial(0, notary, a.info.legalIdentity.ref(1), b.info.legalIdentity.ref(1))
val stx = twoPartyDummyContract.signWith(a.services.legalIdentityKey)
.signWith(b.services.legalIdentityKey)
.toSignedTransaction()
a.services.startFlow(FinalityFlow(stx, setOf(a.info.legalIdentity, b.info.legalIdentity)))
mockNet.runNetwork()
val user = rpcTestUser.copy(permissions = setOf(
startFlowPermission<FinalityFlow>(),
startFlowPermission<ContractUpgradeFlow<*, *>>()
))
val rpcA = startProxy(a, user)
val rpcB = startProxy(b, user)
val handle = rpcA.startFlow(::FinalityFlow, stx, setOf(a.info.legalIdentity, b.info.legalIdentity))
mockNet.runNetwork()
handle.returnValue.getOrThrow()
val atx = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(stx.id) }
val btx = b.database.transaction { b.services.storageService.validatedTransactions.getTransaction(stx.id) }
requireNotNull(atx)
requireNotNull(btx)
val atx = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(stx.id) }
val btx = b.database.transaction { b.services.storageService.validatedTransactions.getTransaction(stx.id) }
requireNotNull(atx)
requireNotNull(btx)
// The request is expected to be rejected because party B haven't authorise the upgrade yet.
val rejectedFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) },
atx!!.tx.outRef<DummyContract.State>(0),
DummyContractV2::class.java).returnValue
val rpcA = CordaRPCOpsImpl(a.services, a.smm, a.database)
val rpcB = CordaRPCOpsImpl(b.services, b.smm, b.database)
mockNet.runNetwork()
assertFailsWith(ExecutionException::class) { rejectedFuture.get() }
CURRENT_RPC_USER.set(User("user", "pwd", permissions = setOf(
startFlowPermission<ContractUpgradeFlow<*, *>>()
)))
// Party B authorise the contract state upgrade.
rpcB.authoriseContractUpgrade(btx!!.tx.outRef<ContractState>(0), DummyContractV2::class.java)
val rejectedFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) },
atx!!.tx.outRef<DummyContract.State>(0),
DummyContractV2::class.java).returnValue
// Party A initiates contract upgrade flow, expected to succeed this time.
val resultFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) },
atx.tx.outRef<DummyContract.State>(0),
DummyContractV2::class.java).returnValue
mockNet.runNetwork()
assertFailsWith(ExecutionException::class) { rejectedFuture.get() }
mockNet.runNetwork()
val result = resultFuture.get()
// Check results.
listOf(a, b).forEach {
val signedTX = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(result.ref.txhash) }
requireNotNull(signedTX)
// Party B authorise the contract state upgrade.
rpcB.authoriseContractUpgrade(btx!!.tx.outRef<ContractState>(0), DummyContractV2::class.java)
// Verify inputs.
val input = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(signedTX!!.tx.inputs.single().txhash) }
requireNotNull(input)
assertTrue(input!!.tx.outputs.single().data is DummyContract.State)
// Party A initiates contract upgrade flow, expected to succeed this time.
val resultFuture = rpcA.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow(stateAndRef, upgrade) },
atx.tx.outRef<DummyContract.State>(0),
DummyContractV2::class.java).returnValue
mockNet.runNetwork()
val result = resultFuture.get()
// Check results.
listOf(a, b).forEach {
val signedTX = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(result.ref.txhash) }
requireNotNull(signedTX)
// Verify inputs.
val input = a.database.transaction { a.services.storageService.validatedTransactions.getTransaction(signedTX!!.tx.inputs.single().txhash) }
requireNotNull(input)
assertTrue(input!!.tx.outputs.single().data is DummyContract.State)
// Verify outputs.
assertTrue(signedTX!!.tx.outputs.single().data is DummyContractV2.State)
// Verify outputs.
assertTrue(signedTX!!.tx.outputs.single().data is DummyContractV2.State)
}
}
}

View File

@ -18,7 +18,10 @@ import net.corda.node.driver.driver
import net.corda.node.services.startFlowPermission
import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.nodeapi.User
import net.corda.testing.*
import net.corda.testing.expect
import net.corda.testing.expectEvents
import net.corda.testing.parallel
import net.corda.testing.sequence
import org.junit.Test
import java.util.*
import kotlin.concurrent.thread
@ -44,12 +47,10 @@ class IntegrationTestingTutorial {
// START 2
val aliceClient = alice.rpcClientToNode()
aliceClient.start("aliceUser", "testPassword1")
val aliceProxy = aliceClient.proxy()
val aliceProxy = aliceClient.start("aliceUser", "testPassword1").proxy
val bobClient = bob.rpcClientToNode()
bobClient.start("bobUser", "testPassword2")
val bobProxy = bobClient.proxy()
val bobProxy = bobClient.start("bobUser", "testPassword2").proxy
// END 2
// START 3

View File

@ -55,8 +55,7 @@ fun main(args: Array<String>) {
// START 2
val client = node.rpcClientToNode()
client.start("user", "password")
val proxy = client.proxy()
val proxy = client.start("user", "password").proxy
thread {
generateTransactions(proxy)

View File

@ -30,10 +30,7 @@ abstract class ArtemisMessagingComponent : SingletonSerializeAsToken() {
const val INTERNAL_PREFIX = "internal."
const val PEERS_PREFIX = "${INTERNAL_PREFIX}peers."
const val SERVICES_PREFIX = "${INTERNAL_PREFIX}services."
const val CLIENTS_PREFIX = "clients."
const val P2P_QUEUE = "p2p.inbound"
const val RPC_REQUESTS_QUEUE = "rpc.requests"
const val RPC_QUEUE_REMOVALS_QUEUE = "rpc.qremovals"
const val NOTIFICATIONS_ADDRESS = "${INTERNAL_PREFIX}activemq.notifications"
const val NETWORK_MAP_QUEUE = "${INTERNAL_PREFIX}networkmap"

View File

@ -0,0 +1,206 @@
package net.corda.nodeapi
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.ErrorOr
import net.corda.core.serialization.KryoPoolWithContext
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.nodeapi.RPCApi.ClientToServer
import net.corda.nodeapi.RPCApi.ObservableId
import net.corda.nodeapi.RPCApi.RPC_CLIENT_BINDING_REMOVALS
import net.corda.nodeapi.RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX
import net.corda.nodeapi.RPCApi.RPC_SERVER_QUEUE_NAME
import net.corda.nodeapi.RPCApi.RpcRequestId
import net.corda.nodeapi.RPCApi.ServerToClient
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.management.CoreNotificationType
import org.apache.activemq.artemis.api.core.management.ManagementHelper
import org.apache.activemq.artemis.reader.MessageUtil
import rx.Notification
import java.util.*
/**
* The RPC protocol:
*
* The server consumes the queue "[RPC_SERVER_QUEUE_NAME]" and receives RPC requests ([ClientToServer.RpcRequest]) on it.
* When a client starts up it should create a queue for its inbound messages, this should be of the form
* "[RPC_CLIENT_QUEUE_NAME_PREFIX].$username.$nonce". Each RPC request contains this address (in
* [ClientToServer.RpcRequest.clientAddress]), this is where the server will send the reply to the request as well as
* subsequent Observations rooted in the RPC. The requests/replies are muxed using a unique [RpcRequestId] generated by
* the client for each request.
*
* If an RPC reply's payload ([ServerToClient.RpcReply.result]) contains [Observable]s then the server will generate a
* unique [ObservableId] for each and serialise them in place of the [Observable]s themselves. Subsequently the client
* should be prepared to receive observations ([ServerToClient.Observation]), muxed by the relevant [ObservableId].
* In addition each observation itself may contain further [Observable]s, this case should behave the same as before.
*
* Additionally the client may send [ClientToServer.ObservablesClosed] messages indicating that certain observables
* aren't consumed anymore, which should subsequently stop the stream from the server. Note that some observations may
* already be in flight when this is sent, the client should handle this gracefully.
*
* An example session:
* Client Server
* ----------RpcRequest(RID0)-----------> // Client makes RPC request with ID "RID0"
* <----RpcReply(RID0, Payload(OID0))---- // Server sends reply containing an observable with ID "OID0"
* <---------Observation(OID0)----------- // Server sends observation onto "OID0"
* <---Observation(OID0, Payload(OID1))-- // Server sends another observation, this time containing another observable
* <---------Observation(OID1)----------- // Observation onto new "OID1"
* <---------Observation(OID0)-----------
* -----ObservablesClosed(OID0, OID1)---> // Client indicates it stopped consuming the Observables.
* <---------Observation(OID1)----------- // Observation was already in-flight before the previous message was processed
* (FIN)
*
* Note that multiple sessions like the above may interleave in an arbitrary fashion.
*
* Additionally the server may listen on client binding removals for cleanup using [RPC_CLIENT_BINDING_REMOVALS]. This
* requires the server to create a filter on the artemis notification address using [RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION]
*/
object RPCApi {
private val TAG_FIELD_NAME = "tag"
private val RPC_ID_FIELD_NAME = "rpc-id"
private val OBSERVABLE_ID_FIELD_NAME = "observable-id"
private val METHOD_NAME_FIELD_NAME = "method-name"
val RPC_SERVER_QUEUE_NAME = "rpc.server"
val RPC_CLIENT_QUEUE_NAME_PREFIX = "rpc.client"
val RPC_CLIENT_BINDING_REMOVALS = "rpc.clientqueueremovals"
val RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION =
"${ManagementHelper.HDR_NOTIFICATION_TYPE} = '${CoreNotificationType.BINDING_REMOVED.name}' AND " +
"${ManagementHelper.HDR_ROUTING_NAME} LIKE '$RPC_CLIENT_QUEUE_NAME_PREFIX.%'"
data class RpcRequestId(val toLong: Long)
data class ObservableId(val toLong: Long)
object RpcRequestOrObservableIdKey
private fun ClientMessage.getBodyAsByteArray(): ByteArray {
return ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }
}
sealed class ClientToServer {
private enum class Tag {
RPC_REQUEST,
OBSERVABLES_CLOSED
}
data class RpcRequest(
val clientAddress: SimpleString,
val id: RpcRequestId,
val methodName: String,
val arguments: List<Any?>
) : ClientToServer() {
fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) {
MessageUtil.setJMSReplyTo(message, clientAddress)
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal)
message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong)
message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName)
message.bodyBuffer.writeBytes(arguments.serialize(kryoPool).bytes)
}
}
data class ObservablesClosed(
val ids: List<ObservableId>
) : ClientToServer() {
fun writeToClientMessage(message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVABLES_CLOSED.ordinal)
val buffer = message.bodyBuffer
buffer.writeInt(ids.size)
ids.forEach {
buffer.writeLong(it.toLong)
}
}
}
companion object {
fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ClientToServer {
val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)]
return when (tag) {
RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest(
clientAddress = MessageUtil.getJMSReplyTo(message),
id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)),
methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME),
arguments = message.getBodyAsByteArray().deserialize(kryoPool)
)
RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> {
val ids = ArrayList<ObservableId>()
val buffer = message.bodyBuffer
val numberOfIds = buffer.readInt()
for (i in 1 .. numberOfIds) {
ids.add(ObservableId(buffer.readLong()))
}
ObservablesClosed(ids)
}
}
}
}
}
sealed class ServerToClient {
private enum class Tag {
RPC_REPLY,
OBSERVATION
}
abstract fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage)
data class RpcReply(
val id: RpcRequestId,
val result: ErrorOr<Any?>
) : ServerToClient() {
override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal)
message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong)
message.bodyBuffer.writeBytes(result.serialize(kryoPool).bytes)
}
}
data class Observation(
val id: ObservableId,
val content: Notification<Any>
) : ServerToClient() {
override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal)
message.putLongProperty(OBSERVABLE_ID_FIELD_NAME, id.toLong)
message.bodyBuffer.writeBytes(content.serialize(kryoPool).bytes)
}
}
companion object {
fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ServerToClient {
val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)]
return when (tag) {
RPCApi.ServerToClient.Tag.RPC_REPLY -> {
val id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME))
val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong)
RpcReply(
id = id,
result = message.getBodyAsByteArray().deserialize(poolWithIdContext)
)
}
RPCApi.ServerToClient.Tag.OBSERVATION -> {
val id = ObservableId(message.getLongProperty(OBSERVABLE_ID_FIELD_NAME))
val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong)
Observation(
id = id,
content = message.getBodyAsByteArray().deserialize(poolWithIdContext)
)
}
}
}
}
}
}
data class ArtemisProducer(
val sessionFactory: ClientSessionFactory,
val session: ClientSession,
val producer: ClientProducer
)
data class ArtemisConsumer(
val sessionFactory: ClientSessionFactory,
val session: ClientSession,
val consumer: ClientConsumer
)

View File

@ -10,17 +10,10 @@ import net.corda.core.serialization.*
import net.corda.core.toFuture
import net.corda.core.toObservable
import net.corda.nodeapi.config.OldConfig
import org.apache.commons.fileupload.MultipartStream
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import rx.Notification
import rx.Observable
/** Global RPC logger */
val rpcLog: Logger by lazy { LoggerFactory.getLogger("net.corda.rpc") }
/** Used in the RPC wire protocol to wrap an observation with the handle of the observable it's intended for. */
data class MarshalledObservation(val forHandle: Int, val what: Notification<*>)
import java.io.InputStream
import java.io.PrintWriter
import java.io.StringWriter
data class User(
@OldConfig("user")
@ -35,28 +28,6 @@ data class User(
@MustBeDocumented
annotation class RPCSinceVersion(val version: Int)
/** The contents of an RPC request message, separated from the MQ layer. */
data class ClientRPCRequestMessage(
val args: SerializedBytes<Array<Any>>,
val replyToAddress: String,
val observationsToAddress: String?,
val methodName: String,
val user: User
) {
companion object {
const val REPLY_TO = "reply-to"
const val OBSERVATIONS_TO = "observations-to"
const val METHOD_NAME = "method-name"
}
}
/**
* This is available to RPC implementations to query the validated [User] that is calling it. Each user has a set of
* permissions they're entitled to which can be used to control access.
*/
@JvmField
val CURRENT_RPC_USER: ThreadLocal<User> = ThreadLocal()
/**
* Thrown to indicate a fatal error in the RPC system itself, as opposed to an error generated by the invoked
* method.
@ -64,19 +35,11 @@ val CURRENT_RPC_USER: ThreadLocal<User> = ThreadLocal()
@CordaSerializable
open class RPCException(msg: String, cause: Throwable?) : RuntimeException(msg, cause) {
constructor(msg: String) : this(msg, null)
class DeadlineExceeded(rpcName: String) : RPCException("Deadline exceeded on call to $rpcName")
}
@CordaSerializable
class PermissionException(msg: String) : RuntimeException(msg)
object RPCKryoClientKey
object RPCKryoDispatcherKey
object RPCKryoQNameKey
object RPCKryoMethodNameKey
object RPCKryoLocationKey
// The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly.
// This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes,
// because we can see everything we're using in one place.
@ -85,8 +48,7 @@ class RPCKryo(observableSerializer: Serializer<Observable<Any>>) : CordaKryo(mak
DefaultKryoCustomizer.customize(this)
// RPC specific classes
register(MultipartStream.ItemInputStream::class.java, InputStreamSerializer)
register(MarshalledObservation::class.java, ImmutableClassSerializer(MarshalledObservation::class))
register(InputStream::class.java, InputStreamSerializer)
register(Observable::class.java, observableSerializer)
@Suppress("UNCHECKED_CAST")
register(ListenableFuture::class,
@ -110,14 +72,14 @@ class RPCKryo(observableSerializer: Serializer<Observable<Any>>) : CordaKryo(mak
}
override fun getRegistration(type: Class<*>): Registration {
val annotated = context[RPCKryoQNameKey] != null
if (Observable::class.java.isAssignableFrom(type)) {
return if (annotated) super.getRegistration(Observable::class.java)
else throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables")
if (Observable::class.java != type && Observable::class.java.isAssignableFrom(type)) {
return super.getRegistration(Observable::class.java)
}
if (ListenableFuture::class.java.isAssignableFrom(type)) {
return if (annotated) super.getRegistration(ListenableFuture::class.java)
else throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables")
if (InputStream::class.java != type && InputStream::class.java.isAssignableFrom(type)) {
return super.getRegistration(InputStream::class.java)
}
if (ListenableFuture::class.java != type && ListenableFuture::class.java.isAssignableFrom(type)) {
return super.getRegistration(ListenableFuture::class.java)
}
if (FlowException::class.java.isAssignableFrom(type))
return super.getRegistration(FlowException::class.java)

View File

@ -17,9 +17,8 @@ class BootTests {
fun `java deserialization is disabled`() {
driver {
val user = User("u", "p", setOf(startFlowPermission<ObjectInputStreamFlow>()))
val future = startNode(rpcUsers = listOf(user)).getOrThrow().rpcClientToNode().apply {
start(user.username, user.password)
}.proxy().startFlow(::ObjectInputStreamFlow).returnValue
val future = startNode(rpcUsers = listOf(user)).getOrThrow().rpcClientToNode().
start(user.username, user.password).proxy.startFlow(::ObjectInputStreamFlow).returnValue
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(InvalidClassException::class.java).hasMessage("filter status: REJECTED")
}
}

View File

@ -75,8 +75,10 @@ class DriverTests {
driver(isDebug = true, systemProperties = mapOf("log4j.configurationFile" to logConfigFile.toString())) {
val baseDirectory = startNode(DUMMY_BANK_A.name).getOrThrow().configuration.baseDirectory
val logFile = (baseDirectory / LOGS_DIRECTORY_NAME).list { it.sorted().findFirst().get() }
println("ASD $logFile")
val debugLinesPresent = logFile.readLines { lines -> lines.anyMatch { line -> line.startsWith("[DEBUG]") } }
assertThat(debugLinesPresent).isTrue()
println("hmm.")
}
}

View File

@ -60,8 +60,7 @@ class DistributedServiceTests : DriverBasedTest() {
// Connect to Alice and the notaries
fun connectRpc(node: NodeHandle): CordaRPCOps {
val client = node.rpcClientToNode()
client.start("test", "test")
return client.proxy()
return client.start("test", "test").proxy
}
aliceProxy = connectRpc(alice)
val rpcClientsToNotaries = notaries.map(::connectRpc)

View File

@ -2,7 +2,7 @@ package net.corda.services.messaging
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEER_USER
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.RPC_REQUESTS_QUEUE
import net.corda.nodeapi.RPCApi
import net.corda.testing.messaging.SimpleMQClient
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration
import org.apache.activemq.artemis.api.core.ActiveMQClusterSecurityException
@ -24,7 +24,7 @@ class MQSecurityAsNodeTest : MQSecurityTest() {
@Test
fun `send message to RPC requests address`() {
assertSendAttackFails(RPC_REQUESTS_QUEUE)
assertSendAttackFails(RPCApi.RPC_SERVER_QUEUE_NAME)
}
@Test

View File

@ -1,6 +1,7 @@
package net.corda.services.messaging
import net.corda.nodeapi.User
import net.corda.testing.configureTestSSL
import net.corda.testing.messaging.SimpleMQClient
import org.apache.activemq.artemis.api.core.ActiveMQSecurityException
import org.assertj.core.api.Assertions.assertThatExceptionOfType
@ -23,14 +24,13 @@ class MQSecurityAsRPCTest : MQSecurityTest() {
override val extraRPCUsers = listOf(User("evil", "pass", permissions = emptySet()))
override fun startAttacker(attacker: SimpleMQClient) {
attacker.loginToRPC(extraRPCUsers[0])
attacker.start(extraRPCUsers[0].username, extraRPCUsers[0].password, false)
}
@Test
fun `login to a ssl port as a RPC user`() {
val attacker = clientTo(alice.configuration.p2pAddress)
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
attacker.loginToRPC(extraRPCUsers[0], enableSSL = true)
loginToRPC(alice.configuration.p2pAddress, extraRPCUsers[0], configureTestSSL())
}
}
}

View File

@ -2,7 +2,7 @@ package net.corda.services.messaging
import co.paralleluniverse.fibers.Suspendable
import com.google.common.net.HostAndPort
import net.corda.client.rpc.CordaRPCClientImpl
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.crypto.Party
import net.corda.core.crypto.generateKeyPair
import net.corda.core.crypto.toBase58String
@ -10,19 +10,16 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.getOrThrow
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.random63BitValue
import net.corda.core.seconds
import net.corda.core.utilities.ALICE
import net.corda.core.utilities.BOB
import net.corda.core.utilities.unwrap
import net.corda.node.internal.Node
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.CLIENTS_PREFIX
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.INTERNAL_PREFIX
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NETWORK_MAP_QUEUE
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NOTIFICATIONS_ADDRESS
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.P2P_QUEUE
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEERS_PREFIX
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.RPC_QUEUE_REMOVALS_QUEUE
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.RPC_REQUESTS_QUEUE
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.User
import net.corda.nodeapi.config.SSLConfiguration
import net.corda.testing.configureTestSSL
@ -36,7 +33,6 @@ import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.*
import java.util.concurrent.locks.ReentrantLock
import kotlin.test.assertEquals
/**
@ -108,7 +104,7 @@ abstract class MQSecurityTest : NodeBasedTest() {
@Test
fun `consume message from RPC requests queue`() {
assertConsumeAttackFails(RPC_REQUESTS_QUEUE)
assertConsumeAttackFails(RPCApi.RPC_SERVER_QUEUE_NAME)
}
@Test
@ -119,21 +115,16 @@ abstract class MQSecurityTest : NodeBasedTest() {
@Test
fun `create queue for valid RPC user`() {
val user1Queue = "$CLIENTS_PREFIX${rpcUser.username}.rpc.${random63BitValue()}"
val user1Queue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.${rpcUser.username}.${random63BitValue()}"
assertTempQueueCreationAttackFails(user1Queue)
}
@Test
fun `create queue for invalid RPC user`() {
val invalidRPCQueue = "$CLIENTS_PREFIX${random63BitValue()}.rpc.${random63BitValue()}"
val invalidRPCQueue = "${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.${random63BitValue()}.${random63BitValue()}"
assertTempQueueCreationAttackFails(invalidRPCQueue)
}
@Test
fun `consume message from RPC queue removals queue`() {
assertConsumeAttackFails(RPC_QUEUE_REMOVALS_QUEUE)
}
@Test
fun `send message to notifications address`() {
assertSendAttackFails(NOTIFICATIONS_ADDRESS)
@ -157,22 +148,16 @@ abstract class MQSecurityTest : NodeBasedTest() {
return client
}
fun loginToRPC(target: HostAndPort, rpcUser: User, sslConfiguration: SSLConfiguration? = null): SimpleMQClient {
val client = clientTo(target, sslConfiguration)
client.loginToRPC(rpcUser)
return client
}
fun SimpleMQClient.loginToRPC(rpcUser: User, enableSSL: Boolean = false): CordaRPCOps {
start(rpcUser.username, rpcUser.password, enableSSL)
val clientImpl = CordaRPCClientImpl(session, ReentrantLock(), rpcUser.username)
return clientImpl.proxyFor(CordaRPCOps::class.java, timeout = 1.seconds)
fun loginToRPC(target: HostAndPort, rpcUser: User, sslConfiguration: SSLConfiguration? = null): CordaRPCOps {
return CordaRPCClient(target, sslConfiguration).start(rpcUser.username, rpcUser.password).proxy
}
fun loginToRPCAndGetClientQueue(): String {
val rpcClient = loginToRPC(alice.configuration.rpcAddress!!, rpcUser)
val clientQueueQuery = SimpleString("$CLIENTS_PREFIX${rpcUser.username}.rpc.*")
return rpcClient.session.addressQuery(clientQueueQuery).queueNames.single().toString()
loginToRPC(alice.configuration.rpcAddress!!, rpcUser)
val clientQueueQuery = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.${rpcUser.username}.*")
val client = clientTo(alice.configuration.rpcAddress!!)
client.start(rpcUser.username, rpcUser.password, false)
return client.session.addressQuery(clientQueueQuery).queueNames.single().toString()
}
fun assertAllQueueCreationAttacksFail(queue: String) {

View File

@ -7,13 +7,10 @@ import com.google.common.util.concurrent.*
import com.typesafe.config.Config
import com.typesafe.config.ConfigRenderOptions
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.ThreadBox
import net.corda.core.*
import net.corda.core.crypto.Party
import net.corda.core.crypto.X509Utilities
import net.corda.core.crypto.commonName
import net.corda.core.div
import net.corda.core.flatMap
import net.corda.core.map
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.ServiceInfo
@ -40,6 +37,7 @@ import java.io.File
import java.net.*
import java.nio.file.Path
import java.nio.file.Paths
import java.time.Duration
import java.time.Instant
import java.time.ZoneOffset.UTC
import java.time.format.DateTimeFormatter
@ -110,6 +108,26 @@ interface DriverDSLExposedInterface {
fun startNetworkMapService()
fun waitForAllNodesToFinish()
/**
* Polls a function until it returns a non-null value. Note that there is no timeout on the polling.
*
* @param pollName A description of what is being polled.
* @param pollInterval The interval of polling.
* @param warnCount The number of polls after the Driver gives a warning.
* @param check The function being polled.
* @return A future that completes with the non-null value [check] has returned.
*/
fun <A> pollUntilNonNull(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> A?): ListenableFuture<A>
/**
* Polls the given function until it returns true.
* @see pollUntilNonNull
*/
fun pollUntilTrue(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> Boolean): ListenableFuture<Unit> {
return pollUntilNonNull(pollName, pollInterval, warnCount) { if (check()) Unit else null }
}
val shutdownManager: ShutdownManager
}
interface DriverDSLInternalInterface : DriverDSLExposedInterface {
@ -216,15 +234,13 @@ fun <DI : DriverDSLExposedInterface, D : DriverDSLInternalInterface, A> genericD
var shutdownHook: Thread? = null
try {
driverDsl.start()
val returnValue = dsl(coerce(driverDsl))
shutdownHook = Thread({
driverDsl.shutdown()
})
Runtime.getRuntime().addShutdownHook(shutdownHook)
return returnValue
return dsl(coerce(driverDsl))
} catch (exception: Throwable) {
println("Driver shutting down because of exception $exception")
exception.printStackTrace()
log.error("Driver shutting down because of exception", exception)
throw exception
} finally {
driverDsl.shutdown()
@ -271,7 +287,7 @@ fun addressMustNotBeBound(executorService: ScheduledExecutorService, hostAndPort
fun <A> poll(
executorService: ScheduledExecutorService,
pollName: String,
pollIntervalMs: Long = 500,
pollInterval: Duration = 500.millis,
warnCount: Int = 120,
check: () -> A?
): ListenableFuture<A> {
@ -286,7 +302,7 @@ fun <A> poll(
executorService.schedule(task@ {
counter++
if (counter == warnCount) {
log.warn("Been polling $pollName for ${pollIntervalMs * warnCount / 1000.0} seconds...")
log.warn("Been polling $pollName for ${pollInterval.seconds * warnCount} seconds...")
}
val result = try {
check()
@ -299,7 +315,7 @@ fun <A> poll(
} else {
resultFuture.set(result)
}
}, pollIntervalMs, MILLISECONDS)
}, pollInterval.toMillis(), MILLISECONDS)
}
schedulePoll()
return resultFuture
@ -326,7 +342,13 @@ class ShutdownManager(private val executorService: ExecutorService) {
/** Could not get all of them, collect what we have */
shutdownFutures.filter { it.isDone }.map { it.get() }
}
shutdowns.reversed().forEach { it() }
shutdowns.reversed().forEach { shutdown ->
try {
shutdown()
} catch (throwable: Throwable) {
log.error("Exception while shutting down", throwable)
}
}
}
fun registerShutdown(shutdown: ListenableFuture<() -> Unit>) {
@ -335,6 +357,7 @@ class ShutdownManager(private val executorService: ExecutorService) {
registeredShutdowns.add(shutdown)
}
}
fun registerShutdown(shutdown: () -> Unit) = registerShutdown(Futures.immediateFuture(shutdown))
fun registerProcessShutdown(processFuture: ListenableFuture<Process>) {
val processShutdown = processFuture.map { process ->
@ -368,8 +391,10 @@ class DriverDSL(
) : DriverDSLInternalInterface {
private val networkMapLegalName = DUMMY_MAP.name
private val networkMapAddress = portAllocation.nextHostAndPort()
val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(2))
val shutdownManager = ShutdownManager(executorService)
val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator(
Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build())
)
override val shutdownManager = ShutdownManager(executorService)
class State {
val processes = ArrayList<ListenableFuture<Process>>()
@ -401,9 +426,6 @@ class DriverDSL(
override fun shutdown() {
shutdownManager.shutdown()
// Check that we shut down properly
addressMustNotBeBound(executorService, networkMapAddress).get()
executorService.shutdown()
}
@ -411,8 +433,9 @@ class DriverDSL(
val client = CordaRPCClient(nodeAddress, sslConfig)
return poll(executorService, "for RPC connection") {
try {
client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER)
return@poll client.proxy()
val connection = client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER)
shutdownManager.registerShutdown { connection.close() }
return@poll connection.proxy
} catch(e: Exception) {
log.error("Exception $e, Retrying RPC connection at $nodeAddress")
null
@ -566,6 +589,12 @@ class DriverDSL(
registerProcess(startNode)
}
override fun <A> pollUntilNonNull(pollName: String, pollInterval: Duration, warnCount: Int, check: () -> A?): ListenableFuture<A> {
val pollFuture = poll(executorService, pollName, pollInterval, warnCount, check)
shutdownManager.registerShutdown { pollFuture.cancel(true) }
return pollFuture
}
companion object {
val name = arrayOf(
ALICE.name,

View File

@ -97,7 +97,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
CashExitFlow::class.java to setOf(Amount::class.java, PartyAndReference::class.java),
CashIssueFlow::class.java to setOf(Amount::class.java, OpaqueBytes::class.java, Party::class.java),
CashPaymentFlow::class.java to setOf(Amount::class.java, Party::class.java),
FinalityFlow::class.java to emptySet(),
FinalityFlow::class.java to setOf(LinkedHashSet::class.java),
ContractUpgradeFlow::class.java to emptySet()
)
}

View File

@ -21,11 +21,11 @@ import net.corda.core.node.services.vault.Sort
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.getRpcContext
import net.corda.node.services.messaging.requirePermission
import net.corda.node.services.startFlowPermission
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.utilities.transaction
import net.corda.nodeapi.CURRENT_RPC_USER
import org.bouncycastle.asn1.x500.X500Name
import org.jetbrains.exposed.sql.Database
import rx.Observable
@ -121,8 +121,9 @@ class CordaRPCOpsImpl(
// TODO: Check that this flow is annotated as being intended for RPC invocation
override fun <T : Any> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> {
requirePermission(startFlowPermission(logicType))
val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username)
val rpcContext = getRpcContext()
rpcContext.requirePermission(startFlowPermission(logicType))
val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username)
val stateMachine = services.invokeFlowAsync(logicType, currentUser, *args)
return FlowProgressHandleImpl(
id = stateMachine.id,
@ -133,8 +134,9 @@ class CordaRPCOpsImpl(
// TODO: Check that this flow is annotated as being intended for RPC invocation
override fun <T : Any> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowHandle<T> {
requirePermission(startFlowPermission(logicType))
val currentUser = FlowInitiator.RPC(CURRENT_RPC_USER.get().username)
val rpcContext = getRpcContext()
rpcContext.requirePermission(startFlowPermission(logicType))
val currentUser = FlowInitiator.RPC(rpcContext.currentUser.username)
val stateMachine = services.invokeFlowAsync(logicType, currentUser, *args)
return FlowHandleImpl(id = stateMachine.id, returnValue = stateMachine.resultFuture)
}

View File

@ -256,7 +256,7 @@ class Node(override val configuration: FullNodeConfiguration,
/** Starts a blocking event loop for message dispatch. */
fun run() {
(net as NodeMessagingClient).run()
(net as NodeMessagingClient).run(messageBroker!!.serverControl)
}
// TODO: Do we really need setup?

View File

@ -21,10 +21,10 @@ import net.corda.node.services.messaging.NodeLoginModule.Companion.PEER_ROLE
import net.corda.node.services.messaging.NodeLoginModule.Companion.RPC_ROLE
import net.corda.node.services.messaging.NodeLoginModule.Companion.VERIFIER_ROLE
import net.corda.nodeapi.*
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.CLIENTS_PREFIX
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEER_USER
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
import org.apache.activemq.artemis.core.config.BridgeConfiguration
import org.apache.activemq.artemis.core.config.Configuration
import org.apache.activemq.artemis.core.config.CoreQueueConfiguration
@ -37,6 +37,8 @@ import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactor
import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.core.settings.impl.AddressFullMessagePolicy
import org.apache.activemq.artemis.core.settings.impl.AddressSettings
import org.apache.activemq.artemis.spi.core.remoting.*
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import org.apache.activemq.artemis.spi.core.security.jaas.CertificateCallback
@ -97,6 +99,7 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
private val mutex = ThreadBox(InnerState())
private lateinit var activeMQServer: ActiveMQServer
val serverControl: ActiveMQServerControl get() = activeMQServer.activeMQServerControl
private val _networkMapConnectionFuture = config.networkMapService?.let { SettableFuture.create<Unit>() }
/**
* A [ListenableFuture] which completes when the server successfully connects to the network map node. If a
@ -185,10 +188,19 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
// Create an RPC queue: this will service locally connected clients only (not via a bridge) and those
// clients must have authenticated. We could use a single consumer for everything and perhaps we should,
// but these queues are not worth persisting.
queueConfig(RPC_REQUESTS_QUEUE, durable = false),
// The custom name for the queue is intentional - we may wish other things to subscribe to the
// NOTIFICATIONS_ADDRESS with different filters in future
queueConfig(RPC_QUEUE_REMOVALS_QUEUE, address = NOTIFICATIONS_ADDRESS, filter = "_AMQ_NotifType = 1", durable = false)
queueConfig(RPCApi.RPC_SERVER_QUEUE_NAME, durable = false),
queueConfig(
name = RPCApi.RPC_CLIENT_BINDING_REMOVALS,
address = NOTIFICATIONS_ADDRESS,
filter = RPCApi.RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION,
durable = false
)
)
addressesSettings = mapOf(
"${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.#" to AddressSettings().apply {
maxSizeBytes = 10L * MAX_FILE_SIZE
addressFullMessagePolicy = AddressFullMessagePolicy.FAIL
}
)
configureAddressSecurity()
}
@ -213,16 +225,16 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
val nodeInternalRole = Role(NODE_ROLE, true, true, true, true, true, true, true, true)
securityRoles["$INTERNAL_PREFIX#"] = setOf(nodeInternalRole) // Do not add any other roles here as it's only for the node
securityRoles[P2P_QUEUE] = setOf(nodeInternalRole, restrictedRole(PEER_ROLE, send = true))
securityRoles[RPC_REQUESTS_QUEUE] = setOf(nodeInternalRole, restrictedRole(RPC_ROLE, send = true))
securityRoles[RPCApi.RPC_SERVER_QUEUE_NAME] = setOf(nodeInternalRole, restrictedRole(RPC_ROLE, send = true))
// TODO remove the NODE_USER role once the webserver doesn't need it
securityRoles["$CLIENTS_PREFIX$NODE_USER.rpc.*"] = setOf(nodeInternalRole)
securityRoles["${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$NODE_USER.#"] = setOf(nodeInternalRole)
for ((username) in userService.users) {
securityRoles["$CLIENTS_PREFIX$username.rpc.*"] = setOf(
securityRoles["${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.#"] = setOf(
nodeInternalRole,
restrictedRole("$CLIENTS_PREFIX$username", consume = true, createNonDurableQueue = true, deleteNonDurableQueue = true))
restrictedRole("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username", consume = true, createNonDurableQueue = true, deleteNonDurableQueue = true))
}
securityRoles[VerifierApi.VERIFICATION_REQUESTS_QUEUE_NAME] = setOf(nodeInternalRole, restrictedRole(VERIFIER_ROLE, consume = true))
securityRoles["${VerifierApi.VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX}.*"] = setOf(nodeInternalRole, restrictedRole(VERIFIER_ROLE, send = true))
securityRoles["${VerifierApi.VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX}.#"] = setOf(nodeInternalRole, restrictedRole(VERIFIER_ROLE, send = true))
}
private fun restrictedRole(name: String, send: Boolean = false, consume: Boolean = false, createDurableQueue: Boolean = false,
@ -629,7 +641,7 @@ class NodeLoginModule : LoginModule {
throw FailedLoginException("Password for user $username does not match")
}
principals += RolePrincipal(RPC_ROLE) // This enables the RPC client to send requests
principals += RolePrincipal("$CLIENTS_PREFIX$username") // This enables the RPC client to receive responses
principals += RolePrincipal("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username") // This enables the RPC client to receive responses
return username
}

View File

@ -11,7 +11,6 @@ import net.corda.core.node.VersionInfo
import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.TransactionVerifierService
import net.corda.core.random63BitValue
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.opaque
import net.corda.core.success
import net.corda.core.transactions.LedgerTransaction
@ -25,10 +24,7 @@ import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.services.transactions.InMemoryTransactionVerifierService
import net.corda.node.services.transactions.OutOfProcessTransactionVerifierService
import net.corda.node.utilities.*
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.VerifierApi
import net.corda.nodeapi.*
import net.corda.nodeapi.VerifierApi.VERIFICATION_REQUESTS_QUEUE_NAME
import net.corda.nodeapi.VerifierApi.VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
@ -36,6 +32,7 @@ import org.apache.activemq.artemis.api.core.Message.*
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
import org.bouncycastle.asn1.x500.X500Name
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.ResultRow
@ -71,7 +68,7 @@ class NodeMessagingClient(override val config: NodeConfiguration,
val versionInfo: VersionInfo,
val serverHostPort: HostAndPort,
val myIdentity: PublicKey?,
val nodeExecutor: AffinityExecutor,
val nodeExecutor: AffinityExecutor.ServiceAffinityExecutor,
val database: Database,
val networkMapRegistrationFuture: ListenableFuture<Unit>,
val monitoringService: MonitoringService
@ -100,11 +97,9 @@ class NodeMessagingClient(override val config: NodeConfiguration,
var producer: ClientProducer? = null
var p2pConsumer: ClientConsumer? = null
var session: ClientSession? = null
var clientFactory: ClientSessionFactory? = null
var rpcDispatcher: RPCDispatcher? = null
var sessionFactory: ClientSessionFactory? = null
var rpcServer: RPCServer? = null
// Consumer for inbound client RPC messages.
var rpcConsumer: ClientConsumer? = null
var rpcNotificationConsumer: ClientConsumer? = null
var verificationResponseConsumer: ClientConsumer? = null
}
@ -163,18 +158,19 @@ class NodeMessagingClient(override val config: NodeConfiguration,
val tcpTransport = ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), serverHostPort, config)
val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport)
locator.minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE
clientFactory = locator.createSessionFactory()
sessionFactory = locator.createSessionFactory()
// Login using the node username. The broker will authentiate us as its node (as opposed to another peer)
// using our TLS certificate.
// Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer
// size of 1MB is acknowledged.
val session = clientFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE)
val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE)
this.session = session
session.start()
// Create a general purpose producer.
producer = session.createProducer()
val producer = session.createProducer()
this.producer = producer
// Create a queue, consumer and producer for handling P2P network messages.
p2pConsumer = makeP2PConsumer(session, true)
@ -190,9 +186,7 @@ class NodeMessagingClient(override val config: NodeConfiguration,
}
}
rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE)
rpcNotificationConsumer = session.createConsumer(RPC_QUEUE_REMOVALS_QUEUE)
rpcDispatcher = createRPCDispatcher(rpcOps, userService, config.myLegalName)
rpcServer = RPCServer(rpcOps, NODE_USER, NODE_USER, locator, userService, config.myLegalName)
fun checkVerifierCount() {
if (session.queueQuery(SimpleString(VERIFICATION_REQUESTS_QUEUE_NAME)).consumerCount == 0) {
@ -269,12 +263,12 @@ class NodeMessagingClient(override val config: NodeConfiguration,
return true
}
private fun runPreNetworkMap() {
private fun runPreNetworkMap(serverControl: ActiveMQServerControl) {
val consumer = state.locked {
check(started) { "start must be called first" }
check(!running) { "run can't be called twice" }
running = true
rpcDispatcher!!.start(rpcConsumer!!, rpcNotificationConsumer!!, nodeExecutor)
rpcServer!!.start(serverControl)
(verifierService as? OutOfProcessTransactionVerifierService)?.start(verificationResponseConsumer!!)
p2pConsumer!!
}
@ -300,9 +294,9 @@ class NodeMessagingClient(override val config: NodeConfiguration,
* we get our network map fetch response. At that point the filtering consumer is closed and we proceed to the second loop and
* consume all messages via a new consumer without a filter applied.
*/
fun run() {
fun run(serverControl: ActiveMQServerControl) {
// Build the network map.
runPreNetworkMap()
runPreNetworkMap(serverControl)
// Process everything else once we have the network map.
runPostNetworkMap()
shutdownLatch.countDown()
@ -404,17 +398,13 @@ class NodeMessagingClient(override val config: NodeConfiguration,
// Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests.
if (running) {
state.locked {
rpcConsumer?.close()
rpcConsumer = null
rpcNotificationConsumer?.close()
rpcNotificationConsumer = null
producer?.close()
producer = null
// Ensure any trailing messages are committed to the journal
session!!.commit()
// Closing the factory closes all the sessions it produced as well.
clientFactory!!.close()
clientFactory = null
sessionFactory!!.close()
sessionFactory = null
}
}
}
@ -547,22 +537,6 @@ class NodeMessagingClient(override val config: NodeConfiguration,
}
}
private fun createRPCDispatcher(ops: RPCOps, userService: RPCUserService, nodeLegalName: X500Name): RPCDispatcher =
object : RPCDispatcher(ops, userService, nodeLegalName) {
override fun send(data: SerializedBytes<*>, toAddress: String) {
messagingExecutor.fetchFrom {
state.locked {
val msg = session!!.createMessage(false).apply {
writeBodyBufferBytes(data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
producer!!.send(toAddress, msg)
}
}
}
}
private fun createOutOfProcessVerifierService(): TransactionVerifierService {
return object : OutOfProcessTransactionVerifierService(monitoringService) {
override fun sendRequest(nonce: Long, transaction: LedgerTransaction) {

View File

@ -1,219 +0,0 @@
package net.corda.node.services.messaging
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.annotations.VisibleForTesting
import com.google.common.collect.HashMultimap
import net.corda.core.ErrorOr
import net.corda.core.crypto.commonName
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.RPCReturnsObservables
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.debug
import net.corda.node.services.RPCUserService
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.*
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER
import org.apache.activemq.artemis.api.core.Message
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.bouncycastle.asn1.x500.X500Name
import rx.Notification
import rx.Observable
import rx.Subscription
import java.lang.reflect.InvocationTargetException
import java.util.concurrent.atomic.AtomicInteger
/**
* Intended to service transient clients only (not p2p nodes) for short-lived, transient request/response pairs.
* If you need robustness, this is the wrong system. If you don't want a response, this is probably the
* wrong system (you could just send a message). If you want complex customisation of how requests/responses
* are handled, this is probably the wrong system.
*/
// TODO remove the nodeLegalName parameter once the webserver doesn't need special privileges
abstract class RPCDispatcher(val ops: RPCOps, val userService: RPCUserService, val nodeLegalName: X500Name) {
// Throw an exception if there are overloaded methods
private val methodTable = ops.javaClass.declaredMethods.groupBy { it.name }.mapValues { it.value.single() }
private val queueToSubscription = HashMultimap.create<String, Subscription>()
private val handleCounter = AtomicInteger()
// Created afresh for every RPC that is annotated as returning observables. Every time an observable is
// encountered either in the RPC response or in an object graph that is being emitted by one of those
// observables, the handle counter is incremented and the server-side observable is subscribed to. The
// materialized observations are then sent to the queue the client created where they can be picked up.
//
// When the observables are deserialised on the client side, the handle is read from the byte stream and
// the queue is filtered to extract just those observations.
class ObservableSerializer : Serializer<Observable<Any>>() {
private fun toQName(kryo: Kryo): String = kryo.context[RPCKryoQNameKey] as String
private fun toDispatcher(kryo: Kryo): RPCDispatcher = kryo.context[RPCKryoDispatcherKey] as RPCDispatcher
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
throw UnsupportedOperationException("not implemented")
}
override fun write(kryo: Kryo, output: Output, obj: Observable<Any>) {
val qName = toQName(kryo)
val dispatcher = toDispatcher(kryo)
val handle = dispatcher.handleCounter.andIncrement
output.writeInt(handle, true)
// Observables can do three kinds of callback: "next" with a content object, "completed" and "error".
// Materializing the observable converts these three kinds of callback into a single stream of objects
// representing what happened, which is useful for us to send over the wire.
val subscription = obj.materialize().subscribe { materialised: Notification<out Any> ->
val newKryo = createRPCKryoForSerialization(qName, dispatcher)
val bits = try {
MarshalledObservation(handle, materialised).serialize(newKryo)
} finally {
releaseRPCKryoForSerialization(newKryo)
}
rpcLog.debug("RPC sending observation: $materialised")
dispatcher.send(bits, qName)
}
synchronized(dispatcher.queueToSubscription) {
dispatcher.queueToSubscription.put(qName, subscription)
}
}
}
fun dispatch(msg: ClientRPCRequestMessage) {
val (argsBytes, replyTo, observationsTo, methodName) = msg
val response: ErrorOr<Any> = ErrorOr.catch {
val method = methodTable[methodName] ?: throw RPCException("Received RPC for unknown method $methodName - possible client/server version skew?")
if (method.isAnnotationPresent(RPCReturnsObservables::class.java) && observationsTo == null)
throw RPCException("Received RPC without any destination for observations, but the RPC returns observables")
val kryo = createRPCKryoForSerialization(observationsTo, this)
val args = try {
argsBytes.deserialize(kryo)
} finally {
releaseRPCKryoForSerialization(kryo)
}
rpcLog.debug { "-> RPC -> $methodName(${args.joinToString()}) [reply to $replyTo]" }
try {
method.invoke(ops, *args)
} catch (e: InvocationTargetException) {
throw e.cause!!
}
}
rpcLog.debug { "<- RPC <- $methodName = $response " }
// Serialise, or send back a simple serialised ErrorOr structure if we couldn't do it.
val kryo = createRPCKryoForSerialization(observationsTo, this)
val responseBits = try {
response.serialize(kryo)
} catch (e: KryoException) {
rpcLog.error("Failed to respond to inbound RPC $methodName", e)
ErrorOr.of(e).serialize(kryo)
} finally {
releaseRPCKryoForSerialization(kryo)
}
send(responseBits, replyTo)
}
abstract fun send(data: SerializedBytes<*>, toAddress: String)
fun start(rpcConsumer: ClientConsumer, rpcNotificationConsumer: ClientConsumer?, onExecutor: AffinityExecutor) {
rpcNotificationConsumer?.setMessageHandler { msg ->
val qName = msg.getStringProperty("_AMQ_RoutingName")
val subscriptions = synchronized(queueToSubscription) {
queueToSubscription.removeAll(qName)
}
if (subscriptions.isNotEmpty()) {
rpcLog.debug("Observable queue was deleted, unsubscribing: $qName")
subscriptions.forEach { it.unsubscribe() }
}
}
rpcConsumer.setMessageHandler { msg ->
msg.acknowledge()
// All RPCs run on the main server thread, in order to avoid running concurrently with
// potentially state changing requests from other nodes and each other. If we need to
// give better latency to client RPCs in future we could use an executor that supports
// job priorities.
onExecutor.execute {
try {
val rpcMessage = msg.toRPCRequestMessage()
CURRENT_RPC_USER.set(rpcMessage.user)
dispatch(rpcMessage)
} catch(e: RPCException) {
rpcLog.warn("Received malformed client RPC message: ${e.message}")
rpcLog.trace("RPC exception", e)
} catch(e: Throwable) {
rpcLog.error("Uncaught exception when dispatching client RPC", e)
} finally {
CURRENT_RPC_USER.remove()
}
}
}
}
private fun ClientMessage.requiredString(name: String): String {
return getStringProperty(name) ?: throw RPCException("missing $name property")
}
/** Convert an Artemis [ClientMessage] to a MQ-neutral [ClientRPCRequestMessage]. */
private fun ClientMessage.toRPCRequestMessage(): ClientRPCRequestMessage {
val user = getUser(this)
val replyTo = getReturnAddress(user, ClientRPCRequestMessage.REPLY_TO, true)!!
val observationsTo = getReturnAddress(user, ClientRPCRequestMessage.OBSERVATIONS_TO, false)
val argBytes = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }
if (argBytes.isEmpty()) {
throw RPCException("empty serialized args")
}
val methodName = requiredString(ClientRPCRequestMessage.METHOD_NAME)
return ClientRPCRequestMessage(SerializedBytes(argBytes), replyTo, observationsTo, methodName, user)
}
// TODO remove this User once webserver doesn't need it
private val nodeUser = User(NODE_USER, NODE_USER, setOf())
@VisibleForTesting
protected open fun getUser(message: ClientMessage): User {
val validatedUser = message.requiredString(Message.HDR_VALIDATED_USER.toString())
val rpcUser = userService.getUser(validatedUser)
if (rpcUser != null) {
return rpcUser
} else {
try {
if (X500Name(validatedUser) == nodeLegalName) {
return nodeUser
}
} catch (ex: IllegalArgumentException) {
// Just means the two can't be compared, treat as no match
}
throw IllegalArgumentException("Validated user '$validatedUser' is not an RPC user nor the NODE user")
}
}
private fun ClientMessage.getReturnAddress(user: User, property: String, required: Boolean): String? {
return if (containsProperty(property)) {
"${ArtemisMessagingComponent.CLIENTS_PREFIX}${user.username}.rpc.${getLongProperty(property)}"
} else {
if (required) throw RPCException("missing $property property") else null
}
}
}
private val rpcSerKryoPool = KryoPool.Builder { RPCKryo(RPCDispatcher.ObservableSerializer()) }.build()
fun createRPCKryoForSerialization(qName: String? = null, dispatcher: RPCDispatcher? = null): Kryo {
val kryo = rpcSerKryoPool.borrow()
kryo.context.put(RPCKryoQNameKey, qName)
kryo.context.put(RPCKryoDispatcherKey, dispatcher)
return kryo
}
fun releaseRPCKryoForSerialization(kryo: Kryo) {
rpcSerKryoPool.release(kryo)
}

View File

@ -0,0 +1,373 @@
package net.corda.node.services.messaging
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import com.google.common.cache.RemovalListener
import com.google.common.collect.HashMultimap
import com.google.common.collect.Multimaps
import com.google.common.collect.SetMultimap
import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.core.ErrorOr
import net.corda.core.crypto.commonName
import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue
import net.corda.core.seconds
import net.corda.core.serialization.KryoPoolWithContext
import net.corda.core.utilities.LazyStickyPool
import net.corda.core.utilities.LifeCycle
import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor
import net.corda.node.services.RPCUserService
import net.corda.nodeapi.*
import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER
import org.apache.activemq.artemis.api.core.Message
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ServerLocator
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
import org.apache.activemq.artemis.api.core.management.CoreNotificationType
import org.apache.activemq.artemis.api.core.management.ManagementHelper
import org.bouncycastle.asn1.x500.X500Name
import rx.Notification
import rx.Observable
import rx.Subscriber
import rx.Subscription
import java.lang.reflect.InvocationTargetException
import java.time.Duration
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
data class RPCServerConfiguration(
/** The number of threads to use for handling RPC requests */
val rpcThreadPoolSize: Int,
/** The number of consumers to handle incoming messages */
val consumerPoolSize: Int,
/** The maximum number of producers to create to handle outgoing messages */
val producerPoolBound: Int,
/** The interval of subscription reaping */
val reapInterval: Duration
) {
companion object {
val default = RPCServerConfiguration(
rpcThreadPoolSize = 4,
consumerPoolSize = 2,
producerPoolBound = 4,
reapInterval = 1.seconds
)
}
}
/**
* The [RPCServer] implements the complement of [RPCClient]. When an RPC request arrives it dispatches to the
* corresponding function in [ops]. During serialisation of the reply (and later observations) the server subscribes to
* each Observable it encounters and captures the client address to associate with these Observables. Later it uses this
* address to forward observations arriving on the Observables.
*
* The way this is done is similar to that in [RPCClient], we use Kryo and add a context to stores the subscription map.
*/
class RPCServer(
private val ops: RPCOps,
private val rpcServerUsername: String,
private val rpcServerPassword: String,
private val serverLocator: ServerLocator,
private val userService: RPCUserService,
private val nodeLegalName: X500Name,
private val rpcConfiguration: RPCServerConfiguration = RPCServerConfiguration.default
) {
private companion object {
val log = loggerFor<RPCServer>()
val kryoPool = KryoPool.Builder { RPCKryo(RpcServerObservableSerializer) }.build()
}
private enum class State {
UNSTARTED,
STARTED,
FINISHED
}
private val lifeCycle = LifeCycle(State.UNSTARTED)
// The methodname->Method map to use for dispatching.
private val methodTable = ops.javaClass.declaredMethods.groupBy { it.name }.mapValues { it.value.single() }
// The observable subscription mapping.
private val observableMap = createObservableSubscriptionMap()
// A mapping from client addresses to IDs of associated Observables
private val clientAddressToObservables = Multimaps.synchronizedSetMultimap(HashMultimap.create<SimpleString, RPCApi.ObservableId>())
// The scheduled reaper handle.
private lateinit var reaperScheduledFuture: ScheduledFuture<*>
private val observationSendExecutor = Executors.newFixedThreadPool(
1,
ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build()
)
private val rpcExecutor = Executors.newScheduledThreadPool(
rpcConfiguration.rpcThreadPoolSize,
ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build()
)
private val reaperExecutor = Executors.newScheduledThreadPool(
1,
ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build()
)
private val sessionAndConsumers = ArrayList<ArtemisConsumer>(rpcConfiguration.consumerPoolSize)
private val sessionAndProducerPool = LazyStickyPool(rpcConfiguration.producerPoolBound) {
val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
session.start()
ArtemisProducer(sessionFactory, session, session.createProducer())
}
private lateinit var clientBindingRemovalConsumer: ClientConsumer
private lateinit var serverControl: ActiveMQServerControl
private fun createObservableSubscriptionMap(): ObservableSubscriptionMap {
val onObservableRemove = RemovalListener<RPCApi.ObservableId, ObservableSubscription> {
log.debug { "Unsubscribing from Observable with id ${it.key} because of ${it.cause}" }
it.value.subscription.unsubscribe()
}
return CacheBuilder.newBuilder().removalListener(onObservableRemove).build()
}
fun start(activeMqServerControl: ActiveMQServerControl) {
log.info("Starting RPC server with configuration $rpcConfiguration")
reaperScheduledFuture = reaperExecutor.scheduleAtFixedRate(
this::reapSubscriptions,
rpcConfiguration.reapInterval.toMillis(),
rpcConfiguration.reapInterval.toMillis(),
TimeUnit.MILLISECONDS
)
val sessions = ArrayList<ClientSession>()
for (i in 1 .. rpcConfiguration.consumerPoolSize) {
val sessionFactory = serverLocator.createSessionFactory()
val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME)
consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler)
sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer))
sessions.add(session)
}
clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS)
clientBindingRemovalConsumer.setMessageHandler(this::bindingRemovalArtemisMessageHandler)
serverControl = activeMqServerControl
lifeCycle.transition(State.UNSTARTED, State.STARTED)
// We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be
// fully initialised.
sessions.forEach {
it.start()
}
}
fun close() {
reaperScheduledFuture.cancel(false)
rpcExecutor.shutdownNow()
reaperExecutor.shutdownNow()
rpcExecutor.awaitTermination(500, TimeUnit.MILLISECONDS)
reaperExecutor.awaitTermination(500, TimeUnit.MILLISECONDS)
sessionAndConsumers.forEach {
it.consumer.close()
it.session.close()
it.sessionFactory.close()
}
observableMap.invalidateAll()
reapSubscriptions()
sessionAndProducerPool.close().forEach {
it.producer.close()
it.session.close()
it.sessionFactory.close()
}
lifeCycle.transition(State.STARTED, State.FINISHED)
}
private fun bindingRemovalArtemisMessageHandler(artemisMessage: ClientMessage) {
lifeCycle.requireState(State.STARTED)
val notificationType = artemisMessage.getStringProperty(ManagementHelper.HDR_NOTIFICATION_TYPE)
require(notificationType == CoreNotificationType.BINDING_REMOVED.name)
val clientAddress = artemisMessage.getStringProperty(ManagementHelper.HDR_ROUTING_NAME)
log.warn("Detected RPC client disconnect on address $clientAddress, scheduling for reaping")
invalidateClient(SimpleString(clientAddress))
}
// Note that this function operates on the *current* view of client observables. During invalidation further
// Observables may be serialised and thus registered.
private fun invalidateClient(clientAddress: SimpleString) {
lifeCycle.requireState(State.STARTED)
val observableIds = clientAddressToObservables.removeAll(clientAddress)
observableMap.invalidateAll(observableIds)
}
private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) {
lifeCycle.requireState(State.STARTED)
val clientToServer = RPCApi.ClientToServer.fromClientMessage(kryoPool, artemisMessage)
log.debug { "-> RPC -> $clientToServer" }
when (clientToServer) {
is RPCApi.ClientToServer.RpcRequest -> {
val rpcContext = RpcContext(
currentUser = getUser(artemisMessage)
)
rpcExecutor.submit {
val result = ErrorOr.catch {
try {
CURRENT_RPC_CONTEXT.set(rpcContext)
log.debug { "Calling ${clientToServer.methodName}" }
val method = methodTable[clientToServer.methodName] ?:
throw RPCException("Received RPC for unknown method ${clientToServer.methodName} - possible client/server version skew?")
method.invoke(ops, *clientToServer.arguments.toTypedArray())
} finally {
CURRENT_RPC_CONTEXT.remove()
}
}
val resultWithExceptionUnwrapped = result.mapError {
if (it is InvocationTargetException) {
it.cause ?: RPCException("Caught InvocationTargetException without cause")
} else {
it
}
}
val reply = RPCApi.ServerToClient.RpcReply(
id = clientToServer.id,
result = resultWithExceptionUnwrapped
)
val observableContext = ObservableContext(
clientToServer.id,
observableMap,
clientAddressToObservables,
clientToServer.clientAddress,
serverControl,
sessionAndProducerPool,
observationSendExecutor,
kryoPool
)
observableContext.sendMessage(reply)
}
}
is RPCApi.ClientToServer.ObservablesClosed -> {
observableMap.invalidateAll(clientToServer.ids)
}
}
artemisMessage.acknowledge()
}
private fun reapSubscriptions() {
lifeCycle.requireState(State.STARTED)
observableMap.cleanUp()
}
// TODO remove this User once webserver doesn't need it
private val nodeUser = User(NODE_USER, NODE_USER, setOf())
private fun getUser(message: ClientMessage): User {
val validatedUser = message.getStringProperty(Message.HDR_VALIDATED_USER) ?: throw IllegalArgumentException("Missing validated user from the Artemis message")
val rpcUser = userService.getUser(validatedUser)
if (rpcUser != null) {
return rpcUser
} else if (X500Name(validatedUser) == nodeLegalName) {
return nodeUser
} else {
throw IllegalArgumentException("Validated user '$validatedUser' is not an RPC user nor the NODE user")
}
}
}
@JvmField
internal val CURRENT_RPC_CONTEXT: ThreadLocal<RpcContext> = ThreadLocal()
/**
* Returns a context specific to the current RPC call. Note that trying to call this function outside of an RPC will
* throw. If you'd like to use the context outside of the call (e.g. in another thread) then pass the returned reference
* around explicitly.
*/
fun getRpcContext(): RpcContext = CURRENT_RPC_CONTEXT.get()
/**
* @param currentUser This is available to RPC implementations to query the validated [User] that is calling it. Each
* user has a set of permissions they're entitled to which can be used to control access.
*/
data class RpcContext(
val currentUser: User
)
class ObservableSubscription(
val subscription: Subscription
)
typealias ObservableSubscriptionMap = Cache<RPCApi.ObservableId, ObservableSubscription>
// We construct an observable context on each RPC request. If subsequently a nested Observable is
// encountered this same context is propagated by the instrumented KryoPool. This way all
// observations rooted in a single RPC will be muxed correctly. Note that the context construction
// itself is quite cheap.
class ObservableContext(
val rpcRequestId: RPCApi.RpcRequestId,
val observableMap: ObservableSubscriptionMap,
val clientAddressToObservables: SetMultimap<SimpleString, RPCApi.ObservableId>,
val clientAddress: SimpleString,
val serverControl: ActiveMQServerControl,
val sessionAndProducerPool: LazyStickyPool<ArtemisProducer>,
val observationSendExecutor: ExecutorService,
kryoPool: KryoPool
) {
private companion object {
val log = loggerFor<ObservableContext>()
}
private val kryoPoolWithObservableContext = RpcServerObservableSerializer.createPoolWithContext(kryoPool, this)
fun sendMessage(serverToClient: RPCApi.ServerToClient) {
try {
sessionAndProducerPool.run(rpcRequestId) {
val artemisMessage = it.session.createMessage(false)
serverToClient.writeToClientMessage(kryoPoolWithObservableContext, artemisMessage)
it.producer.send(clientAddress, artemisMessage)
log.debug("<- RPC <- $serverToClient")
}
} catch (throwable: Throwable) {
log.error("Failed to send message, kicking client. Message was $serverToClient", throwable)
serverControl.closeConsumerConnectionsForAddress(clientAddress.toString())
}
}
}
private object RpcServerObservableSerializer : Serializer<Observable<Any>>() {
private object RpcObservableContextKey
private val log = loggerFor<RpcServerObservableSerializer>()
fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool {
return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext)
}
override fun read(kryo: Kryo?, input: Input?, type: Class<Observable<Any>>?): Observable<Any> {
throw UnsupportedOperationException()
}
override fun write(kryo: Kryo, output: Output, observable: Observable<Any>) {
val observableId = RPCApi.ObservableId(random63BitValue())
val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext
output.writeLong(observableId.toLong, true)
val observableWithSubscription = ObservableSubscription(
// We capture [observableContext] in the subscriber. Note that all synchronisation/kryo borrowing
// must be done again within the subscriber
subscription = observable.materialize().subscribe(
object : Subscriber<Notification<Any>>() {
override fun onNext(observation: Notification<Any>) {
if (!isUnsubscribed) {
observableContext.observationSendExecutor.submit {
observableContext.sendMessage(RPCApi.ServerToClient.Observation(observableId, observation))
}
}
}
override fun onError(exception: Throwable) {
log.error("onError called in materialize()d RPC Observable", exception)
}
override fun onCompleted() {
}
}
)
)
observableContext.clientAddressToObservables.put(observableContext.clientAddress, observableId)
observableContext.observableMap.put(observableId, observableWithSubscription)
}
}

View File

@ -3,13 +3,11 @@
package net.corda.node.services.messaging
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.CURRENT_RPC_USER
import net.corda.nodeapi.PermissionException
/** Helper method which checks that the current RPC user is entitled for the given permission. Throws a [PermissionException] otherwise. */
fun requirePermission(permission: String) {
fun RpcContext.requirePermission(permission: String) {
// TODO remove the NODE_USER condition once webserver doesn't need it
val currentUser = CURRENT_RPC_USER.get()
val currentUserPermissions = currentUser.permissions
if (currentUser.username != ArtemisMessagingComponent.NODE_USER && currentUserPermissions.intersect(listOf(permission, "ALL")).isEmpty()) {
throw PermissionException("User not permissioned for $permission, permissions are $currentUserPermissions")

View File

@ -19,10 +19,11 @@ import net.corda.jackson.JacksonSupport
import net.corda.jackson.StringToMethodCallParser
import net.corda.node.internal.Node
import net.corda.node.printBasicNodeInfo
import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT
import net.corda.node.services.messaging.RpcContext
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.utilities.ANSIProgressRenderer
import net.corda.nodeapi.ArtemisMessagingComponent
import net.corda.nodeapi.CURRENT_RPC_USER
import net.corda.nodeapi.User
import org.crsh.command.InvocationContext
import org.crsh.console.jline.JLineProcessor
@ -120,7 +121,7 @@ object InteractiveShell {
InterruptHandler { jlineProcessor.interrupt() }.install()
thread(name = "Command line shell processor", isDaemon = true) {
// Give whoever has local shell access administrator access to the node.
CURRENT_RPC_USER.set(User(ArtemisMessagingComponent.NODE_USER, "", setOf()))
CURRENT_RPC_CONTEXT.set(RpcContext(User(ArtemisMessagingComponent.NODE_USER, "", setOf())))
Emoji.renderIfSupported {
jlineProcessor.run()
}

View File

@ -15,11 +15,12 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow
import net.corda.node.internal.CordaRPCOpsImpl
import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT
import net.corda.node.services.messaging.RpcContext
import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.startFlowPermission
import net.corda.node.services.transactions.SimpleNotaryService
import net.corda.node.utilities.transaction
import net.corda.nodeapi.CURRENT_RPC_USER
import net.corda.nodeapi.PermissionException
import net.corda.nodeapi.User
import net.corda.testing.expect
@ -57,10 +58,10 @@ class CordaRPCOpsImplTest {
aliceNode = network.createNode(networkMapAddress = networkMap.info.address)
notaryNode = network.createNode(advertisedServices = ServiceInfo(SimpleNotaryService.type), networkMapAddress = networkMap.info.address)
rpc = CordaRPCOpsImpl(aliceNode.services, aliceNode.smm, aliceNode.database)
CURRENT_RPC_USER.set(User("user", "pwd", permissions = setOf(
CURRENT_RPC_CONTEXT.set(RpcContext(User("user", "pwd", permissions = setOf(
startFlowPermission<CashIssueFlow>(),
startFlowPermission<CashPaymentFlow>()
)))
))))
aliceNode.database.transaction {
stateMachineUpdates = rpc.stateMachinesAndUpdates().second
@ -194,7 +195,7 @@ class CordaRPCOpsImplTest {
@Test
fun `cash command by user not permissioned for cash`() {
CURRENT_RPC_USER.set(User("user", "pwd", permissions = emptySet()))
CURRENT_RPC_CONTEXT.set(RpcContext(User("user", "pwd", permissions = emptySet())))
assertThatExceptionOfType(PermissionException::class.java).isThrownBy {
rpc.startFlow(::CashIssueFlow,
Amount(100, USD),

View File

@ -214,7 +214,7 @@ class ArtemisMessagingTests {
receivedMessages.add(message)
}
// Run after the handlers are added, otherwise (some of) the messages get delivered and discarded / dead-lettered.
thread { messagingClient.run() }
thread { messagingClient.run(messagingServer!!.serverControl) }
return messagingClient
}

View File

@ -25,14 +25,14 @@ class AttachmentDemoTest {
).getOrThrow()
val senderThread = CompletableFuture.supplyAsync {
nodeA.rpcClientToNode().use(demoUser[0].username, demoUser[0].password) {
sender(this, numOfExpectedBytes)
nodeA.rpcClientToNode().start(demoUser[0].username, demoUser[0].password).use {
sender(it.proxy, numOfExpectedBytes)
}
}.exceptionally { it.printStackTrace() }
val recipientThread = CompletableFuture.supplyAsync {
nodeB.rpcClientToNode().use(demoUser[0].username, demoUser[0].password) {
recipient(this)
nodeB.rpcClientToNode().start(demoUser[0].username, demoUser[0].password).use {
recipient(it.proxy)
}
}.exceptionally { it.printStackTrace() }

View File

@ -18,10 +18,12 @@ import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.core.utilities.DUMMY_NOTARY_KEY
import net.corda.core.utilities.Emoji
import net.corda.flows.FinalityFlow
import net.corda.node.driver.poll
import java.io.InputStream
import java.net.HttpURLConnection
import java.net.URL
import java.security.PublicKey
import java.util.concurrent.Executors
import java.util.jar.JarInputStream
import javax.servlet.http.HttpServletResponse.SC_OK
import javax.ws.rs.core.HttpHeaders.CONTENT_DISPOSITION
@ -50,15 +52,15 @@ fun main(args: Array<String>) {
Role.SENDER -> {
val host = HostAndPort.fromString("localhost:10006")
println("Connecting to sender node ($host)")
CordaRPCClient(host).use("demo", "demo") {
sender(this)
CordaRPCClient(host).start("demo", "demo").use {
sender(it.proxy)
}
}
Role.RECIPIENT -> {
val host = HostAndPort.fromString("localhost:10009")
println("Connecting to the recipient node ($host)")
CordaRPCClient(host).use("demo", "demo") {
recipient(this)
CordaRPCClient(host).start("demo", "demo").use {
recipient(it.proxy)
}
}
}
@ -72,7 +74,8 @@ fun sender(rpc: CordaRPCOps, numOfClearBytes: Int = 1024) { // default size 1K.
fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256) {
// Get the identity key of the other side (the recipient).
val otherSide: Party = rpc.partyFromX500Name(DUMMY_BANK_B.name) ?: throw IllegalStateException("Could not find counterparty \"${DUMMY_BANK_B.name}\"")
val executor = Executors.newScheduledThreadPool(1)
val otherSide: Party = poll(executor, DUMMY_BANK_B.name.toString()) { rpc.partyFromX500Name(DUMMY_BANK_B.name) }.get()
// Make sure we have the file in storage
if (!rpc.attachmentExists(hash)) {
@ -97,6 +100,7 @@ fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256)
val flowHandle = rpc.startTrackedFlow(::FinalityFlow, stx, setOf(otherSide))
flowHandle.progress.subscribe(::println)
flowHandle.returnValue.getOrThrow()
println("Sent ${stx.id}")
}
fun recipient(rpc: CordaRPCOps) {

View File

@ -26,13 +26,11 @@ class BankOfCordaRPCClientTest {
// Bank of Corda RPC Client
val bocClient = nodeBankOfCorda.rpcClientToNode()
bocClient.start("bocManager", "password1")
val bocProxy = bocClient.proxy()
val bocProxy = bocClient.start("bocManager", "password1").proxy
// Big Corporation RPC Client
val bigCorpClient = nodeBigCorporation.rpcClientToNode()
bigCorpClient.start("bigCorpCFO", "password2")
val bigCorpProxy = bigCorpClient.proxy()
val bigCorpProxy = bigCorpClient.start("bigCorpCFO", "password2").proxy
// Register for Bank of Corda Vault updates
val vaultUpdatesBoc = bocProxy.vaultAndUpdates().second

View File

@ -32,18 +32,19 @@ class BankOfCordaClientApi(val hostAndPort: HostAndPort) {
fun requestRPCIssue(params: IssueRequestParams): SignedTransaction {
val client = CordaRPCClient(hostAndPort)
// TODO: privileged security controls required
client.start("bankUser", "test")
val proxy = client.proxy()
client.start("bankUser", "test").use { connection ->
val proxy = connection.proxy
// Resolve parties via RPC
val issueToParty = proxy.partyFromX500Name(params.issueToPartyName)
?: throw Exception("Unable to locate ${params.issueToPartyName} in Network Map Service")
val issuerBankParty = proxy.partyFromX500Name(params.issuerBankName)
?: throw Exception("Unable to locate ${params.issuerBankName} in Network Map Service")
// Resolve parties via RPC
val issueToParty = proxy.partyFromX500Name(params.issueToPartyName)
?: throw Exception("Unable to locate ${params.issueToPartyName} in Network Map Service")
val issuerBankParty = proxy.partyFromX500Name(params.issuerBankName)
?: throw Exception("Unable to locate ${params.issuerBankName} in Network Map Service")
val amount = Amount(params.amount, currency(params.currency))
val issuerToPartyRef = OpaqueBytes.of(params.issueToPartyRefAsString.toByte())
val amount = Amount(params.amount, currency(params.currency))
val issuerToPartyRef = OpaqueBytes.of(params.issueToPartyRefAsString.toByte())
return proxy.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty).returnValue.getOrThrow()
return proxy.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty).returnValue.getOrThrow()
}
}
}

View File

@ -66,8 +66,7 @@ class IRSDemoTest : IntegrationTestCategory {
fun getFixingDateObservable(config: FullNodeConfiguration): BlockingObservable<LocalDate?> {
val client = CordaRPCClient(config.rpcAddress!!)
client.start("user", "password")
val proxy = client.proxy()
val proxy = client.start("user", "password").proxy
val vaultUpdates = proxy.vaultAndUpdates().second
val fixingDates = vaultUpdates.map { update ->

View File

@ -22,8 +22,8 @@ import kotlin.system.exitProcess
fun main(args: Array<String>) {
val host = HostAndPort.fromString("localhost:10003")
println("Connecting to the recipient node ($host)")
CordaRPCClient(host).use("demo", "demo") {
val api = NotaryDemoClientApi(this)
CordaRPCClient(host).start("demo", "demo").use {
val api = NotaryDemoClientApi(it.proxy)
api.startNotarisation()
}
}

View File

@ -4,11 +4,13 @@ import com.google.common.util.concurrent.Futures
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.contracts.DOLLARS
import net.corda.core.getOrThrow
import net.corda.core.millis
import net.corda.core.node.services.ServiceInfo
import net.corda.core.utilities.DUMMY_BANK_A
import net.corda.core.utilities.DUMMY_BANK_B
import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.flows.IssuerFlow
import net.corda.node.driver.poll
import net.corda.node.services.startFlowPermission
import net.corda.node.services.transactions.SimpleNotaryService
import net.corda.nodeapi.User
@ -17,6 +19,7 @@ import net.corda.testing.node.NodeBasedTest
import net.corda.traderdemo.flow.SellerFlow
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import java.util.concurrent.Executors
class TraderDemoTest : NodeBasedTest() {
@Test
@ -35,7 +38,7 @@ class TraderDemoTest : NodeBasedTest() {
val (nodeARpc, nodeBRpc) = listOf(nodeA, nodeB).map {
val client = CordaRPCClient(it.configuration.rpcAddress!!)
client.start(demoUser[0].username, demoUser[0].password).proxy()
client.start(demoUser[0].username, demoUser[0].password).proxy
}
val clientA = TraderDemoClientApi(nodeARpc)
@ -48,10 +51,19 @@ class TraderDemoTest : NodeBasedTest() {
clientA.runBuyer(amount = 100.DOLLARS)
clientB.runSeller(counterparty = nodeA.info.legalIdentity.name, amount = 5.DOLLARS)
val actualPaper = listOf(clientA.commercialPaperCount, clientB.commercialPaperCount)
assertThat(clientA.cashCount).isGreaterThan(originalACash)
assertThat(clientB.cashCount).isEqualTo(expectedBCash)
assertThat(actualPaper).isEqualTo(expectedPaper)
// Wait until A receives the commercial paper
val executor = Executors.newScheduledThreadPool(1)
poll(executor, "A to be notified of the commercial paper", pollInterval = 100.millis) {
val actualPaper = listOf(clientA.commercialPaperCount, clientB.commercialPaperCount)
if (actualPaper == expectedPaper) {
Unit
} else {
null
}
}.get()
executor.shutdown()
assertThat(clientA.dollarCashBalance).isEqualTo(95.DOLLARS)
assertThat(clientB.dollarCashBalance).isEqualTo(5.DOLLARS)
}

View File

@ -45,13 +45,13 @@ private class TraderDemo {
val role = options.valueOf(roleArg)!!
if (role == Role.BUYER) {
val host = HostAndPort.fromString("localhost:10006")
CordaRPCClient(host).use("demo", "demo") {
TraderDemoClientApi(this).runBuyer()
CordaRPCClient(host).start("demo", "demo").use {
TraderDemoClientApi(it.proxy).runBuyer()
}
} else {
val host = HostAndPort.fromString("localhost:10009")
CordaRPCClient(host).use("demo", "demo") {
TraderDemoClientApi(this).runSeller(1000.DOLLARS, DUMMY_BANK_A.name)
TraderDemoClientApi(it.proxy).runSeller(1000.DOLLARS, DUMMY_BANK_A.name)
}
}
}

View File

@ -16,6 +16,7 @@ dependencies {
compile project(':node')
compile project(':webserver')
compile project(':verifier')
compile project(':client:mock')
compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version"
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"

View File

@ -35,6 +35,7 @@ import java.nio.file.Path
import java.security.KeyPair
import java.security.PublicKey
import java.util.*
import java.util.concurrent.atomic.AtomicInteger
import kotlin.reflect.KClass
/**
@ -90,6 +91,7 @@ val MOCK_VERSION_INFO = VersionInfo(1, "Mock release", "Mock revision", "Mock Ve
fun generateStateRef() = StateRef(SecureHash.randomSHA256(), 0)
private val freePortCounter = AtomicInteger(30000)
/**
* Returns a free port.
*
@ -97,7 +99,7 @@ fun generateStateRef() = StateRef(SecureHash.randomSHA256(), 0)
* Use [getFreeLocalPorts] for getting multiple ports.
*/
fun freeLocalHostAndPort(): HostAndPort {
val freePort = ServerSocket(0).use { it.localPort }
val freePort = freePortCounter.getAndAccumulate(0) { prev, _ -> 30000 + (prev - 30000 + 1) % 10000 }
return HostAndPort.fromParts("localhost", freePort)
}
@ -108,12 +110,8 @@ fun freeLocalHostAndPort(): HostAndPort {
* to the Node, some other process else could allocate the returned ports.
*/
fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List<HostAndPort> {
// Create a bunch of sockets up front.
val sockets = Array(numberToAlloc) { ServerSocket(0) }
val result = sockets.map { HostAndPort.fromParts(hostName, it.localPort) }
// Close sockets only once we've grabbed all the ports we need.
sockets.forEach(ServerSocket::close)
return result
val freePort = freePortCounter.getAndAccumulate(0) { prev, _ -> 30000 + (prev - 30000 + numberToAlloc) % 10000 }
return (freePort .. freePort + numberToAlloc - 1).map { HostAndPort.fromParts(hostName, it) }
}
/**

View File

@ -0,0 +1,53 @@
package net.corda.testing
import kotlin.reflect.KCallable
import kotlin.reflect.jvm.reflect
/**
* These functions may be used to run measurements of a function where the parameters are chosen from corresponding
* [Iterable]s in a lexical manner. An example use case would be benchmarking the speed of a certain function call using
* different combinations of parameters.
*/
@Suppress("UNCHECKED_CAST")
fun <A, R> measure(a: Iterable<A>, f: (A) -> R) =
measure(listOf(a), f.reflect()!!) { (f as ((Any?)->R))(it[0]) }
@Suppress("UNCHECKED_CAST")
fun <A, B, R> measure(a: Iterable<A>, b: Iterable<B>, f: (A, B) -> R) =
measure(listOf(a, b), f.reflect()!!) { (f as ((Any?,Any?)->R))(it[0], it[1]) }
@Suppress("UNCHECKED_CAST")
fun <A, B, C, R> measure(a: Iterable<A>, b: Iterable<B>, c: Iterable<C>, f: (A, B, C) -> R) =
measure(listOf(a, b, c), f.reflect()!!) { (f as ((Any?,Any?,Any?)->R))(it[0], it[1], it[2]) }
@Suppress("UNCHECKED_CAST")
fun <A, B, C, D, R> measure(a: Iterable<A>, b: Iterable<B>, c: Iterable<C>, d: Iterable<D>, f: (A, B, C, D) -> R) =
measure(listOf(a, b, c, d), f.reflect()!!) { (f as ((Any?,Any?,Any?,Any?)->R))(it[0], it[1], it[2], it[3]) }
private fun <R> measure(paramIterables: List<Iterable<Any?>>, kCallable: KCallable<R>, call: (Array<Any?>) -> R): Iterable<MeasureResult<R>> {
val kParameters = kCallable.parameters
return iterateLexical(paramIterables).map { params ->
MeasureResult(
parameters = params.mapIndexed { index, param -> Pair(kParameters[index].name!!, param) },
result = call(params.toTypedArray())
)
}
}
data class MeasureResult<out R>(
val parameters: List<Pair<String, Any?>>,
val result: R
)
fun <A> iterateLexical(iterables: List<Iterable<A>>): Iterable<List<A>> {
val result = ArrayList<List<A>>()
fun iterateLexicalHelper(index: Int, list: List<A>) {
if (index < iterables.size) {
iterables[index].forEach {
iterateLexicalHelper(index + 1, list + it)
}
} else {
result.add(list)
}
}
iterateLexicalHelper(0, emptyList())
return result
}

View File

@ -0,0 +1,474 @@
package net.corda.testing
import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.ListenableFuture
import net.corda.client.mock.Generator
import net.corda.client.mock.generateOrFail
import net.corda.client.mock.int
import net.corda.client.mock.string
import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.div
import net.corda.core.messaging.RPCOps
import net.corda.core.random63BitValue
import net.corda.core.utilities.ProcessUtilities
import net.corda.node.driver.*
import net.corda.node.services.RPCUserService
import net.corda.node.services.messaging.ArtemisMessagingServer
import net.corda.node.services.messaging.RPCServer
import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.User
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
import org.apache.activemq.artemis.core.config.Configuration
import org.apache.activemq.artemis.core.config.CoreQueueConfiguration
import org.apache.activemq.artemis.core.config.impl.ConfigurationImpl
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMAcceptorFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnectorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory
import org.apache.activemq.artemis.core.security.CheckType
import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.embedded.EmbeddedActiveMQ
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.core.settings.impl.AddressFullMessagePolicy
import org.apache.activemq.artemis.core.settings.impl.AddressSettings
import org.apache.activemq.artemis.spi.core.protocol.RemotingConnection
import org.apache.activemq.artemis.spi.core.security.ActiveMQSecurityManager3
import org.bouncycastle.asn1.x500.X500Name
import java.lang.reflect.Method
import java.nio.file.Path
import java.nio.file.Paths
import java.util.*
import javax.security.cert.X509Certificate
interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface {
/**
* Starts an In-VM RPC server. Note that only a single one may be started.
*
* @param rpcUser The single user who can access the server through RPC, and their permissions.
* @param nodeLegalName The legal name of the node to check against to authenticate a super user.
* @param configuration The RPC server configuration.
* @param ops The server-side implementation of the RPC interface.
*/
fun <I : RPCOps> startInVmRpcServer(
rpcUser: User = rpcTestUser,
nodeLegalName: X500Name = fakeNodeLegalName,
maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE,
maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE,
configuration: RPCServerConfiguration = RPCServerConfiguration.default,
ops : I
): ListenableFuture<Unit>
/**
* Starts an In-VM RPC client.
*
* @param rpcOpsClass The [Class] of the RPC interface.
* @param username The username to authenticate with.
* @param password The password to authenticate with.
* @param configuration The RPC client configuration.
*/
fun <I : RPCOps> startInVmRpcClient(
rpcOpsClass: Class<I>,
username: String = rpcTestUser.username,
password: String = rpcTestUser.password,
configuration: RPCClientConfiguration = RPCClientConfiguration.default
): ListenableFuture<I>
/**
* Starts an In-VM Artemis session connecting to the RPC server.
*
* @param username The username to authenticate with.
* @param password The password to authenticate with.
*/
fun startInVmArtemisSession(
username: String = rpcTestUser.username,
password: String = rpcTestUser.password
): ClientSession
/**
* Starts a Netty RPC server.
*
* @param serverName The name of the server, to be used for the folder created for Artemis files.
* @param rpcUser The single user who can access the server through RPC, and their permissions.
* @param nodeLegalName The legal name of the node to check against to authenticate a super user.
* @param configuration The RPC server configuration.
* @param ops The server-side implementation of the RPC interface.
*/
fun <I : RPCOps> startRpcServer(
serverName: String = "driver-rpc-server-${random63BitValue()}",
rpcUser: User = rpcTestUser,
nodeLegalName: X500Name = fakeNodeLegalName,
maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE,
maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE,
configuration: RPCServerConfiguration = RPCServerConfiguration.default,
ops : I
) : ListenableFuture<RpcServerHandle>
/**
* Starts a Netty RPC client.
*
* @param rpcOpsClass The [Class] of the RPC interface.
* @param rpcAddress The address of the RPC server to connect to.
* @param username The username to authenticate with.
* @param password The password to authenticate with.
* @param configuration The RPC client configuration.
*/
fun <I : RPCOps> startRpcClient(
rpcOpsClass: Class<I>,
rpcAddress: HostAndPort,
username: String = rpcTestUser.username,
password: String = rpcTestUser.password,
configuration: RPCClientConfiguration = RPCClientConfiguration.default
): ListenableFuture<I>
/**
* Starts a Netty RPC client in a new JVM process that calls random RPCs with random arguments.
*
* @param rpcOpsClass The [Class] of the RPC interface.
* @param rpcAddress The address of the RPC server to connect to.
* @param username The username to authenticate with.
* @param password The password to authenticate with.
*/
fun <I : RPCOps> startRandomRpcClient(
rpcOpsClass: Class<I>,
rpcAddress: HostAndPort,
username: String = rpcTestUser.username,
password: String = rpcTestUser.password
): ListenableFuture<Process>
/**
* Starts a Netty Artemis session connecting to an RPC server.
*
* @param rpcAddress The address of the RPC server.
* @param username The username to authenticate with.
* @param password The password to authenticate with.
*/
fun startArtemisSession(
rpcAddress: HostAndPort,
username: String = rpcTestUser.username,
password: String = rpcTestUser.password
): ClientSession
}
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.startInVmRpcClient(
username: String = rpcTestUser.username,
password: String = rpcTestUser.password,
configuration: RPCClientConfiguration = RPCClientConfiguration.default
) = startInVmRpcClient(I::class.java, username, password, configuration)
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.startRandomRpcClient(
hostAndPort: HostAndPort,
username: String = rpcTestUser.username,
password: String = rpcTestUser.password
) = startRandomRpcClient(I::class.java, hostAndPort, username, password)
inline fun <reified I : RPCOps> RPCDriverExposedDSLInterface.startRpcClient(
rpcAddress: HostAndPort,
username: String = rpcTestUser.username,
password: String = rpcTestUser.password,
configuration: RPCClientConfiguration = RPCClientConfiguration.default
) = startRpcClient(I::class.java, rpcAddress, username, password, configuration)
interface RPCDriverInternalDSLInterface : DriverDSLInternalInterface, RPCDriverExposedDSLInterface
data class RpcServerHandle(
val hostAndPort: HostAndPort,
val serverControl: ActiveMQServerControl
)
val rpcTestUser = User("user1", "test", permissions = emptySet())
val fakeNodeLegalName = X500Name("CN=not:a:valid:name")
// Use a global pool so that we can run RPC tests in parallel
private val globalPortAllocation = PortAllocation.Incremental(10000)
private val globalDebugPortAllocation = PortAllocation.Incremental(5005)
fun <A> rpcDriver(
isDebug: Boolean = false,
driverDirectory: Path = Paths.get("build", getTimestampAsDirectoryName()),
portAllocation: PortAllocation = globalPortAllocation,
debugPortAllocation: PortAllocation = globalDebugPortAllocation,
systemProperties: Map<String, String> = emptyMap(),
useTestClock: Boolean = false,
automaticallyStartNetworkMap: Boolean = false,
dsl: RPCDriverExposedDSLInterface.() -> A
) = genericDriver(
driverDsl = RPCDriverDSL(
DriverDSL(
portAllocation = portAllocation,
debugPortAllocation = debugPortAllocation,
systemProperties = systemProperties,
driverDirectory = driverDirectory.toAbsolutePath(),
useTestClock = useTestClock,
automaticallyStartNetworkMap = automaticallyStartNetworkMap,
isDebug = isDebug
)
),
coerce = { it },
dsl = dsl
)
private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 {
override fun validateUser(user: String?, password: String?) = isValid(user, password)
override fun validateUserAndRole(user: String?, password: String?, roles: MutableSet<Role>?, checkType: CheckType?) = isValid(user, password)
override fun validateUser(user: String?, password: String?, certificates: Array<out X509Certificate>?): String? {
return validate(user, password)
}
override fun validateUserAndRole(user: String?, password: String?, roles: MutableSet<Role>?, checkType: CheckType?, address: String?, connection: RemotingConnection?): String? {
return validate(user, password)
}
private fun isValid(user: String?, password: String?): Boolean {
return rpcUser.username == user && rpcUser.password == password
}
private fun validate(user: String?, password: String?): String? {
return if (isValid(user, password)) user else null
}
}
data class RPCDriverDSL(
val driverDSL: DriverDSL
) : DriverDSLInternalInterface by driverDSL, RPCDriverInternalDSLInterface {
private companion object {
val notificationAddress = "notifications"
private fun ConfigurationImpl.configureCommonSettings(maxFileSize: Int, maxBufferedBytesPerClient: Long) {
managementNotificationAddress = SimpleString(notificationAddress)
isPopulateValidatedUser = true
journalBufferSize_NIO = maxFileSize
journalBufferSize_AIO = maxFileSize
journalFileSize = maxFileSize
queueConfigurations = listOf(
CoreQueueConfiguration().apply {
name = RPCApi.RPC_SERVER_QUEUE_NAME
address = RPCApi.RPC_SERVER_QUEUE_NAME
isDurable = false
},
CoreQueueConfiguration().apply {
name = RPCApi.RPC_CLIENT_BINDING_REMOVALS
address = notificationAddress
filterString = RPCApi.RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION
isDurable = false
}
)
addressesSettings = mapOf(
"${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.#" to AddressSettings().apply {
maxSizeBytes = maxBufferedBytesPerClient
addressFullMessagePolicy = AddressFullMessagePolicy.FAIL
}
)
}
fun createInVmRpcServerArtemisConfig(maxFileSize: Int, maxBufferedBytesPerClient: Long): Configuration {
return ConfigurationImpl().apply {
acceptorConfigurations = setOf(TransportConfiguration(InVMAcceptorFactory::class.java.name))
isPersistenceEnabled = false
configureCommonSettings(maxFileSize, maxBufferedBytesPerClient)
}
}
fun createRpcServerArtemisConfig(maxFileSize: Int, maxBufferedBytesPerClient: Long, baseDirectory: Path, hostAndPort: HostAndPort): Configuration {
val connectionDirection = ConnectionDirection.Inbound(acceptorFactoryClassName = NettyAcceptorFactory::class.java.name)
return ConfigurationImpl().apply {
val artemisDir = "$baseDirectory/artemis"
bindingsDirectory = "$artemisDir/bindings"
journalDirectory = "$artemisDir/journal"
largeMessagesDirectory = "$artemisDir/large-messages"
acceptorConfigurations = setOf(ArtemisTcpTransport.tcpTransport(connectionDirection, hostAndPort, null))
configureCommonSettings(maxFileSize, maxBufferedBytesPerClient)
}
}
val inVmClientTransportConfiguration = TransportConfiguration(InVMConnectorFactory::class.java.name)
fun createNettyClientTransportConfiguration(hostAndPort: HostAndPort): TransportConfiguration {
return ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), hostAndPort, null)
}
}
override fun <I : RPCOps> startInVmRpcServer(
rpcUser: User,
nodeLegalName: X500Name,
maxFileSize: Int,
maxBufferedBytesPerClient: Long,
configuration: RPCServerConfiguration,
ops: I
): ListenableFuture<Unit> {
return driverDSL.executorService.submit<Unit> {
val artemisConfig = createInVmRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient)
val server = EmbeddedActiveMQ()
server.setConfiguration(artemisConfig)
server.setSecurityManager(SingleUserSecurityManager(rpcUser))
server.start()
driverDSL.shutdownManager.registerShutdown {
server.activeMQServer.stop()
server.stop()
}
startRpcServerWithBrokerRunning(
rpcUser, nodeLegalName, configuration, ops, inVmClientTransportConfiguration,
server.activeMQServer.activeMQServerControl
)
}
}
override fun <I : RPCOps> startInVmRpcClient(rpcOpsClass: Class<I>, username: String, password: String, configuration: RPCClientConfiguration): ListenableFuture<I> {
return driverDSL.executorService.submit<I> {
val client = RPCClient<I>(inVmClientTransportConfiguration, configuration)
val connection = client.start(rpcOpsClass, username, password)
driverDSL.shutdownManager.registerShutdown {
connection.close()
}
connection.proxy
}
}
override fun startInVmArtemisSession(username: String, password: String): ClientSession {
val locator = ActiveMQClient.createServerLocatorWithoutHA(inVmClientTransportConfiguration)
val sessionFactory = locator.createSessionFactory()
val session = sessionFactory.createSession(username, password, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE)
driverDSL.shutdownManager.registerShutdown {
session.close()
sessionFactory.close()
locator.close()
}
return session
}
override fun <I : RPCOps> startRpcServer(
serverName: String,
rpcUser: User,
nodeLegalName: X500Name,
maxFileSize: Int,
maxBufferedBytesPerClient: Long,
configuration: RPCServerConfiguration,
ops: I
): ListenableFuture<RpcServerHandle> {
val hostAndPort = driverDSL.portAllocation.nextHostAndPort()
addressMustNotBeBound(driverDSL.executorService, hostAndPort)
return driverDSL.executorService.submit<RpcServerHandle> {
val artemisConfig = createRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient, driverDSL.driverDirectory / serverName, hostAndPort)
val server = ActiveMQServerImpl(artemisConfig, SingleUserSecurityManager(rpcUser))
server.start()
driverDSL.shutdownManager.registerShutdown {
server.stop()
addressMustNotBeBound(driverDSL.executorService, hostAndPort).get()
}
val transportConfiguration = createNettyClientTransportConfiguration(hostAndPort)
startRpcServerWithBrokerRunning(
rpcUser, nodeLegalName, configuration, ops, transportConfiguration,
server.activeMQServerControl
)
RpcServerHandle(hostAndPort, server.activeMQServerControl)
}
}
override fun <I : RPCOps> startRpcClient(
rpcOpsClass: Class<I>,
rpcAddress: HostAndPort,
username: String,
password: String,
configuration: RPCClientConfiguration
): ListenableFuture<I> {
return driverDSL.executorService.submit<I> {
val client = RPCClient<I>(ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), rpcAddress, null), configuration)
val connection = client.start(rpcOpsClass, username, password)
driverDSL.shutdownManager.registerShutdown {
connection.close()
}
connection.proxy
}
}
override fun <I : RPCOps> startRandomRpcClient(rpcOpsClass: Class<I>, rpcAddress: HostAndPort, username: String, password: String): ListenableFuture<Process> {
val processFuture = driverDSL.executorService.submit<Process> {
ProcessUtilities.startJavaProcess<RandomRpcUser>(listOf(rpcOpsClass.name, rpcAddress.toString(), username, password))
}
driverDSL.shutdownManager.registerProcessShutdown(processFuture)
return processFuture
}
override fun startArtemisSession(rpcAddress: HostAndPort, username: String, password: String): ClientSession {
val locator = ActiveMQClient.createServerLocatorWithoutHA(createNettyClientTransportConfiguration(rpcAddress))
val sessionFactory = locator.createSessionFactory()
val session = sessionFactory.createSession(username, password, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
driverDSL.shutdownManager.registerShutdown {
session.close()
sessionFactory.close()
locator.close()
}
return session
}
private fun <I : RPCOps> startRpcServerWithBrokerRunning(
rpcUser: User,
nodeLegalName: X500Name,
configuration: RPCServerConfiguration,
ops: I,
transportConfiguration: TransportConfiguration,
serverControl: ActiveMQServerControl
) {
val locator = ActiveMQClient.createServerLocatorWithoutHA(transportConfiguration).apply {
minLargeMessageSize = ArtemisMessagingServer.MAX_FILE_SIZE
}
val userService = object : RPCUserService {
override fun getUser(username: String): User? = if (username == rpcUser.username) rpcUser else null
override val users: List<User> get() = listOf(rpcUser)
}
val rpcServer = RPCServer(
ops,
rpcUser.username,
rpcUser.password,
locator,
userService,
nodeLegalName,
configuration
)
driverDSL.shutdownManager.registerShutdown {
rpcServer.close()
locator.close()
}
rpcServer.start(serverControl)
}
}
/**
* An out-of-process RPC user that connects to an RPC server and issues random RPCs with random arguments.
*/
class RandomRpcUser {
companion object {
private inline fun <reified T> HashMap<Class<*>, Generator<*>>.add(generator: Generator<T>) = this.putIfAbsent(T::class.java, generator)
val generatorStore = HashMap<Class<*>, Generator<*>>().apply {
add(Generator.string())
add(Generator.int())
}
data class Call(val method: Method, val call: () -> Any?)
@JvmStatic
fun main(args: Array<String>) {
require(args.size == 4)
@Suppress("UNCHECKED_CAST")
val rpcClass = Class.forName(args[0]) as Class<RPCOps>
val hostAndPort = HostAndPort.fromString(args[1])
val username = args[2]
val password = args[3]
val handle = RPCClient<RPCOps>(hostAndPort, null).start(rpcClass, username, password)
val callGenerators = rpcClass.declaredMethods.map { method ->
Generator.sequence(method.parameters.map {
generatorStore[it.type] ?: throw Exception("No generator for ${it.type}")
}).map { arguments ->
Call(method, { method.invoke(handle.proxy, *arguments.toTypedArray()) })
}
}
val callGenerator = Generator.choice(callGenerators)
val random = SplittableRandom()
while (true) {
val call = callGenerator.generateOrFail(random)
call.call()
Thread.sleep(100)
}
}
}
}

View File

@ -2,15 +2,13 @@ package net.corda.testing.node
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.createDirectories
import net.corda.core.*
import net.corda.core.crypto.X509Utilities
import net.corda.core.crypto.commonName
import net.corda.core.div
import net.corda.core.flatMap
import net.corda.core.map
import net.corda.core.node.services.ServiceInfo
import net.corda.core.node.services.ServiceType
import net.corda.core.utilities.DUMMY_MAP
import net.corda.node.driver.addressMustNotBeBound
import net.corda.node.internal.Node
import net.corda.node.services.config.ConfigHelper
import net.corda.node.services.config.FullNodeConfiguration
@ -26,6 +24,7 @@ import org.junit.After
import org.junit.Rule
import org.junit.rules.TemporaryFolder
import java.util.*
import java.util.concurrent.Executors
import kotlin.concurrent.thread
/**
@ -53,9 +52,18 @@ abstract class NodeBasedTest {
*/
@After
fun stopAllNodes() {
val shutdownExecutor = Executors.newScheduledThreadPool(1)
nodes.forEach(Node::stop)
// Wait until ports are released
val portNotBoundChecks = nodes.flatMap {
listOf(
it.configuration.p2pAddress.let { addressMustNotBeBound(shutdownExecutor, it) },
it.configuration.rpcAddress?.let { addressMustNotBeBound(shutdownExecutor, it) }
)
}.filterNotNull()
nodes.clear()
_networkMapNode = null
Futures.allAsList(portNotBoundChecks).getOrThrow()
}
/**

View File

@ -57,7 +57,7 @@ class SimpleNode(val config: NodeConfiguration, val address: HostAndPort = freeL
},
userService)
thread(name = config.myLegalName.commonName) {
net.run()
net.run(broker.serverControl)
}
}

View File

@ -2,6 +2,7 @@ package net.corda.demobench.rpc
import com.google.common.net.HostAndPort
import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCConnection
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.utilities.loggerFor
import net.corda.demobench.model.NodeConfig
@ -17,14 +18,16 @@ class NodeRPC(config: NodeConfig, start: (NodeConfig, CordaRPCOps) -> Unit, invo
private val rpcClient = CordaRPCClient(HostAndPort.fromParts("localhost", config.rpcPort))
private val timer = Timer()
private val connections = Collections.synchronizedCollection(ArrayList<CordaRPCConnection>())
init {
val setupTask = object : TimerTask() {
override fun run() {
try {
val user = config.users.elementAt(0)
rpcClient.start(user.username, user.password)
val ops = rpcClient.proxy()
val connection = rpcClient.start(user.username, user.password)
connections.add(connection)
val ops = connection.proxy
// Cancel the "setup" task now that we've created the RPC client.
this.cancel()
@ -50,7 +53,7 @@ class NodeRPC(config: NodeConfig, start: (NodeConfig, CordaRPCOps) -> Unit, invo
override fun close() {
timer.cancel()
rpcClient.close()
connections.forEach(CordaRPCConnection::close)
}
}

View File

@ -198,20 +198,20 @@ fun main(args: Array<String>) {
// Register with alice to use alice's RPC proxy to create random events.
val aliceClient = aliceNode.rpcClientToNode()
aliceClient.start(user.username, user.password)
val aliceRPC = aliceClient.proxy()
val aliceConnection = aliceClient.start(user.username, user.password)
val aliceRPC = aliceConnection.proxy
val bobClient = bobNode.rpcClientToNode()
bobClient.start(user.username, user.password)
val bobRPC = bobClient.proxy()
val bobConnection = bobClient.start(user.username, user.password)
val bobRPC = bobConnection.proxy
val issuerClientGBP = issuerNodeGBP.rpcClientToNode()
issuerClientGBP.start(manager.username, manager.password)
val issuerRPCGBP = issuerClientGBP.proxy()
val issuerGBPConnection = issuerClientGBP.start(manager.username, manager.password)
val issuerRPCGBP = issuerGBPConnection.proxy
val issuerClientUSD = issuerNodeUSD.rpcClientToNode()
issuerClientUSD.start(manager.username, manager.password)
val issuerRPCUSD = issuerClientUSD.proxy()
val issuerUSDConnection = issuerClientUSD.start(manager.username, manager.password)
val issuerRPCUSD = issuerUSDConnection.proxy
val issuers = mapOf(USD to issuerRPCUSD, GBP to issuerRPCGBP)
@ -272,10 +272,10 @@ fun main(args: Array<String>) {
}
println("Simulation completed")
aliceClient.close()
bobClient.close()
issuerClientGBP.close()
issuerClientUSD.close()
aliceConnection.close()
bobConnection.close()
issuerGBPConnection.close()
issuerUSDConnection.close()
}
waitForAllNodesToFinish()
}

View File

@ -6,6 +6,7 @@ import com.jcraft.jsch.agentproxy.AgentProxy
import com.jcraft.jsch.agentproxy.connector.SSHAgentConnector
import com.jcraft.jsch.agentproxy.usocket.JNAUSocketFactory
import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCConnection
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.driver.PortAllocation
import org.slf4j.LoggerFactory
@ -137,9 +138,9 @@ class NodeConnection(
private val rpcUsername: String,
private val rpcPassword: String
) : Closeable {
private var client: CordaRPCClient? = null
private var _proxy: CordaRPCOps? = null
val proxy: CordaRPCOps get() = _proxy ?: throw IllegalStateException("proxy requested, but the client is not running")
private val client = CordaRPCClient(localTunnelAddress)
private var connection: CordaRPCConnection? = null
val proxy: CordaRPCOps get() = connection?.proxy ?: throw IllegalStateException("proxy requested, but the client is not running")
data class ShellCommandOutput(
val originalShellCommand: String,
@ -162,32 +163,24 @@ class NodeConnection(
}
fun <A> doWhileClientStopped(action: () -> A): A {
val client = client
val proxy = _proxy
require(client != null && proxy != null) { "doWhileClientStopped called with no running client" }
val connection = connection
require(connection != null) { "doWhileClientStopped called with no running client" }
log.info("Stopping RPC proxy to $hostName, tunnel at $localTunnelAddress")
client!!.close()
connection!!.close()
try {
return action()
} finally {
log.info("Starting new RPC proxy to $hostName, tunnel at $localTunnelAddress")
val newClient = CordaRPCClient(localTunnelAddress)
// TODO expose these somehow?
newClient.start(rpcUsername, rpcPassword)
val newProxy = newClient.proxy()
this.client = newClient
this._proxy = newProxy
val newConnection = client.start(rpcUsername, rpcPassword)
this.connection = newConnection
}
}
fun startClient() {
log.info("Creating RPC proxy to $hostName, tunnel at $localTunnelAddress")
val client = CordaRPCClient(localTunnelAddress)
client.start(rpcUsername, rpcPassword)
val proxy = client.proxy()
connection = client.start(rpcUsername, rpcPassword)
log.info("Proxy created")
this.client = client
this._proxy = proxy
}
/**
@ -229,7 +222,7 @@ class NodeConnection(
}
override fun close() {
client?.close()
connection?.close()
jSchSession.disconnect()
}
}

View File

@ -199,8 +199,8 @@ class NodeWebServer(val config: WebServerConfig) {
private fun connectLocalRpcAsNodeUser(): CordaRPCOps {
log.info("Connecting to node at ${config.p2pAddress} as node user")
val client = CordaRPCClient(config.p2pAddress, config)
client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER)
return client.proxy()
val connection = client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER)
return connection.proxy
}
/** Fetch CordaPluginRegistry classes registered in META-INF/services/net.corda.core.node.CordaPluginRegistry files that exist in the classpath */