From 7924a5a83419ac8562e16dff616653c9a22ccf71 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Tue, 6 Feb 2018 19:02:06 +0000 Subject: [PATCH] Add RPC deduplication to client and server --- .idea/compiler.xml | 3 +- .../net/corda/client/rpc/RPCStabilityTests.kt | 48 +++- .../corda/client/rpc/internal/RPCClient.kt | 11 +- .../rpc/internal/RPCClientProxyHandler.kt | 111 ++++----- .../corda/client/rpc/RPCPerformanceTests.kt | 19 +- .../main/kotlin/net/corda/nodeapi/RPCApi.kt | 58 +++-- .../nodeapi/internal/DeduplicationChecker.kt | 29 +++ .../node/services/messaging/RPCServer.kt | 226 ++++++++++-------- 8 files changed, 306 insertions(+), 199 deletions(-) create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/DeduplicationChecker.kt diff --git a/.idea/compiler.xml b/.idea/compiler.xml index d8d9f1498a..dba1dad5e2 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -72,6 +72,7 @@ + @@ -159,4 +160,4 @@ - + \ No newline at end of file diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index 45b5bd3d0b..de74d563cb 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -19,6 +19,7 @@ import org.apache.activemq.artemis.api.core.SimpleString import org.junit.After import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue +import org.junit.Ignore import org.junit.Rule import org.junit.Test import rx.Observable @@ -127,7 +128,7 @@ class RPCStabilityTests { rpcDriver { fun startAndCloseServer(broker: RpcBrokerHandle) { startRpcServerWithBrokerRunning( - configuration = RPCServerConfiguration.default.copy(consumerPoolSize = 1, producerPoolBound = 1), + configuration = RPCServerConfiguration.default, ops = DummyOps, brokerHandle = broker ).rpcServer.close() @@ -148,7 +149,7 @@ class RPCStabilityTests { @Test fun `rpc client close doesnt leak broker resources`() { rpcDriver { - val server = startRpcServer(configuration = RPCServerConfiguration.default.copy(consumerPoolSize = 1, producerPoolBound = 1), ops = DummyOps).get() + val server = startRpcServer(configuration = RPCServerConfiguration.default, ops = DummyOps).get() RPCClient(server.broker.hostAndPort!!).start(RPCOps::class.java, rpcTestUser.username, rpcTestUser.password).close() val initial = server.broker.getStats() repeat(100) { @@ -337,11 +338,12 @@ class RPCStabilityTests { val request = RPCApi.ClientToServer.RpcRequest( clientAddress = SimpleString(myQueue), methodName = SlowConsumerRPCOps::streamAtInterval.name, - serialisedArguments = listOf(10.millis, 123456).serialize(context = SerializationDefaults.RPC_SERVER_CONTEXT).bytes, + serialisedArguments = listOf(10.millis, 123456).serialize(context = SerializationDefaults.RPC_SERVER_CONTEXT), replyId = Trace.InvocationId.newInstance(), sessionId = Trace.SessionId.newInstance() ) request.writeToClientMessage(message) + message.putLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME, 0) producer.send(message) session.commit() @@ -350,6 +352,46 @@ class RPCStabilityTests { } } + interface StreamOps : RPCOps { + fun stream(streamInterval: Duration): Observable + } + class StreamOpsImpl : StreamOps { + override val protocolVersion = 0 + override fun stream(streamInterval: Duration): Observable { + return Observable.interval(streamInterval.toNanos(), TimeUnit.NANOSECONDS) + } + } + @Ignore("This is flaky as sometimes artemis delivers out of order messages after the kick") + @Test + fun `deduplication on the client side`() { + rpcDriver { + val server = startRpcServer(ops = StreamOpsImpl()).getOrThrow() + val proxy = startRpcClient( + server.broker.hostAndPort!!, + configuration = RPCClientConfiguration.default.copy( + connectionRetryInterval = 1.days // switch off failover + ) + ).getOrThrow() + // Find the internal address of the client + val clientAddress = server.broker.serverControl.addressNames.find { it.startsWith(RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX) } + val events = ArrayList() + // Start streaming an incrementing value 2000 times per second from the server. + val subscription = proxy.stream(streamInterval = Duration.ofNanos(500_000)).subscribe { + events.add(it) + } + // These sleeps are *fine*, the invariant should hold regardless of any delays + Thread.sleep(50) + // Kick the client. This seems to trigger redelivery of (presumably non-acked) messages. + server.broker.serverControl.closeConsumerConnectionsForAddress(clientAddress) + Thread.sleep(50) + subscription.unsubscribe() + for (i in 0 until events.size) { + require(events[i] == i.toLong()) { + "Events not incremental, possible duplicate, ${events[i]} != ${i.toLong()}\nExpected: ${(0..i).toList()}\nGot : $events\n" + } + } + } + } } fun RPCDriverDSL.pollUntilClientNumber(server: RpcServerHandle, expected: Int) { diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index 756f07216e..a2520317d5 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -42,8 +42,6 @@ data class RPCClientConfiguration( val reapInterval: Duration, /** The number of threads to use for observations (for executing [Observable.onNext]) */ val observationExecutorPoolSize: Int, - /** The maximum number of producers to create to handle outgoing messages */ - val producerPoolBound: Int, /** * Determines the concurrency level of the Observable Cache. This is exposed because it implicitly determines * the limit on the number of leaked observables reaped because of garbage collection per reaping. @@ -56,9 +54,12 @@ data class RPCClientConfiguration( val connectionRetryIntervalMultiplier: Double, /** Maximum retry interval */ val connectionMaxRetryInterval: Duration, + /** Maximum reconnect attempts on failover */ val maxReconnectAttempts: Int, /** Maximum file size */ - val maxFileSize: Int + val maxFileSize: Int, + /** The cache expiry of a deduplication watermark per client. */ + val deduplicationCacheExpiry: Duration ) { companion object { val unlimitedReconnectAttempts = -1 @@ -68,14 +69,14 @@ data class RPCClientConfiguration( trackRpcCallSites = false, reapInterval = 1.seconds, observationExecutorPoolSize = 4, - producerPoolBound = 1, cacheConcurrencyLevel = 8, connectionRetryInterval = 5.seconds, connectionRetryIntervalMultiplier = 1.5, connectionMaxRetryInterval = 3.minutes, maxReconnectAttempts = unlimitedReconnectAttempts, /** 10 MiB maximum allowed file size for attachments, including message headers. TODO: acquire this value from Network Map when supported. */ - maxFileSize = 10485760 + maxFileSize = 10485760, + deduplicationCacheExpiry = 1.days ) } } diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index 813569dc89..2f65f21e40 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -15,7 +15,6 @@ import net.corda.client.rpc.RPCSinceVersion import net.corda.core.context.Actor import net.corda.core.context.Trace import net.corda.core.context.Trace.InvocationId -import net.corda.core.internal.LazyPool import net.corda.core.internal.LazyStickyPool import net.corda.core.internal.LifeCycle import net.corda.core.internal.ThreadBox @@ -26,14 +25,12 @@ import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug import net.corda.core.utilities.getOrThrow -import net.corda.nodeapi.ArtemisConsumer -import net.corda.nodeapi.ArtemisProducer import net.corda.nodeapi.RPCApi +import net.corda.nodeapi.internal.DeduplicationChecker import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE -import org.apache.activemq.artemis.api.core.client.ClientMessage -import org.apache.activemq.artemis.api.core.client.ServerLocator import rx.Notification import rx.Observable import rx.subjects.UnicastSubject @@ -43,6 +40,7 @@ import java.time.Instant import java.util.* import java.util.concurrent.* import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.AtomicLong import kotlin.reflect.jvm.javaMethod /** @@ -111,6 +109,8 @@ class RPCClientProxyHandler( // Used for reaping private var reaperExecutor: ScheduledExecutorService? = null + // Used for sending + private var sendExecutor: ExecutorService? = null // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering. private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").setDaemon(true).build() @@ -161,22 +161,14 @@ class RPCClientProxyHandler( build() } - // We cannot pool consumers as we need to preserve the original muxed message order. - // TODO We may need to pool these somehow anyway, otherwise if the server sends many big messages in parallel a - // single consumer may be starved for flow control credits. Recheck this once Artemis's large message streaming is - // integrated properly. - private var sessionAndConsumer: ArtemisConsumer? = null - // Pool producers to reduce contention on the client side. - private val sessionAndProducerPool = LazyPool(bound = rpcConfiguration.producerPoolBound) { - // Note how we create new sessions *and* session factories per producer. - // We cannot simply pool producers on one session because sessions are single threaded. - // We cannot simply pool sessions on one session factory because flow control credits are tied to factories, so - // sessions tend to starve each other when used concurrently. - val sessionFactory = serverLocator.createSessionFactory() - val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - session.start() - ArtemisProducer(sessionFactory, session, session.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME)) - } + private var sessionFactory: ClientSessionFactory? = null + private var producerSession: ClientSession? = null + private var consumerSession: ClientSession? = null + private var rpcProducer: ClientProducer? = null + private var rpcConsumer: ClientConsumer? = null + + private val deduplicationChecker = DeduplicationChecker(rpcConfiguration.deduplicationCacheExpiry) + private val deduplicationSequenceNumber = AtomicLong(0) /** * Start the client. This creates the per-client queue, starts the consumer session and the reaper. @@ -187,22 +179,25 @@ class RPCClientProxyHandler( 1, ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").setDaemon(true).build() ) + sendExecutor = Executors.newSingleThreadExecutor( + ThreadFactoryBuilder().setNameFormat("rpc-client-sender-%d").build() + ) reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate( this::reapObservablesAndNotify, rpcConfiguration.reapInterval.toMillis(), rpcConfiguration.reapInterval.toMillis(), TimeUnit.MILLISECONDS ) - sessionAndProducerPool.run { - it.session.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress) - } - val sessionFactory = serverLocator.createSessionFactory() - val session = sessionFactory.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - val consumer = session.createConsumer(clientAddress) - consumer.setMessageHandler(this@RPCClientProxyHandler::artemisMessageHandler) - sessionAndConsumer = ArtemisConsumer(sessionFactory, session, consumer) + sessionFactory = serverLocator.createSessionFactory() + producerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + rpcProducer = producerSession!!.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME) + consumerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + consumerSession!!.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress) + rpcConsumer = consumerSession!!.createConsumer(clientAddress) + rpcConsumer!!.setMessageHandler(this::artemisMessageHandler) lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) - session.start() + consumerSession!!.start() + producerSession!!.start() } // This is the general function that transforms a client side RPC to internal Artemis messages. @@ -212,7 +207,7 @@ class RPCClientProxyHandler( if (method == toStringMethod) { return "Client RPC proxy for $rpcOpsClass" } - if (sessionAndConsumer!!.session.isClosed) { + if (consumerSession!!.isClosed) { throw RPCException("RPC Proxy is closed") } @@ -220,23 +215,20 @@ class RPCClientProxyHandler( callSiteMap?.set(replyId, Throwable("")) try { val serialisedArguments = (arguments?.toList() ?: emptyList()).serialize(context = serializationContextWithObservableContext) - val request = RPCApi.ClientToServer.RpcRequest(clientAddress, method.name, serialisedArguments.bytes, replyId, sessionId, externalTrace, impersonatedActor) + val request = RPCApi.ClientToServer.RpcRequest( + clientAddress, + method.name, + serialisedArguments, + replyId, + sessionId, + externalTrace, + impersonatedActor + ) val replyFuture = SettableFuture.create() - sessionAndProducerPool.run { - val message = it.session.createMessage(false) - request.writeToClientMessage(message) - - log.debug { - val argumentsString = arguments?.joinToString() ?: "" - "-> RPC(${replyId.value}) -> ${method.name}($argumentsString): ${method.returnType}" - } - - require(rpcReplyMap.put(replyId, replyFuture) == null) { - "Generated several RPC requests with same ID $replyId" - } - it.producer.send(message) - it.session.commit() + require(rpcReplyMap.put(replyId, replyFuture) == null) { + "Generated several RPC requests with same ID $replyId" } + sendMessage(request) return replyFuture.getOrThrow() } catch (e: RuntimeException) { // Already an unchecked exception, so just rethrow it @@ -249,9 +241,24 @@ class RPCClientProxyHandler( } } + private fun sendMessage(message: RPCApi.ClientToServer) { + val artemisMessage = producerSession!!.createMessage(false) + message.writeToClientMessage(artemisMessage) + sendExecutor!!.submit { + artemisMessage.putLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME, deduplicationSequenceNumber.getAndIncrement()) + log.debug { "-> RPC -> $message" } + rpcProducer!!.send(artemisMessage) + } + } + // The handler for Artemis messages. private fun artemisMessageHandler(message: ClientMessage) { val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message) + val deduplicationSequenceNumber = message.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME) + if (deduplicationChecker.checkDuplicateMessageId(serverToClient.deduplicationIdentity, deduplicationSequenceNumber)) { + log.info("Message duplication detected, discarding message") + return + } log.debug { "Got message from RPC server $serverToClient" } when (serverToClient) { is RPCApi.ServerToClient.RpcReply -> { @@ -325,14 +332,12 @@ class RPCClientProxyHandler( * @param notify whether to notify observables or not. */ private fun close(notify: Boolean = true) { - sessionAndConsumer?.sessionFactory?.close() + sessionFactory?.close() reaperScheduledFuture?.cancel(false) observableContext.observableMap.invalidateAll() reapObservables(notify) reaperExecutor?.shutdownNow() - sessionAndProducerPool.close().forEach { - it.sessionFactory.close() - } + sendExecutor?.shutdownNow() // Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may // leak borrowed executors. val observationExecutors = observationExecutorPool.close() @@ -385,11 +390,7 @@ class RPCClientProxyHandler( } if (observableIds != null) { log.debug { "Reaping ${observableIds.size} observables" } - sessionAndProducerPool.run { - val message = it.session.createMessage(false) - RPCApi.ClientToServer.ObservablesClosed(observableIds).writeToClientMessage(message) - it.producer.send(message) - } + sendMessage(RPCApi.ClientToServer.ObservablesClosed(observableIds)) } } } diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt index b945487cb9..689674d22f 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt @@ -3,6 +3,7 @@ package net.corda.client.rpc import com.google.common.base.Stopwatch import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.messaging.RPCOps +import net.corda.core.utilities.days import net.corda.core.utilities.minutes import net.corda.core.utilities.seconds import net.corda.node.services.messaging.RPCServerConfiguration @@ -87,13 +88,10 @@ class RPCPerformanceTests : AbstractRPCTest() { val proxy = testProxy( RPCClientConfiguration.default.copy( cacheConcurrencyLevel = 16, - observationExecutorPoolSize = 2, - producerPoolBound = 2 + observationExecutorPoolSize = 2 ), RPCServerConfiguration.default.copy( - rpcThreadPoolSize = 8, - consumerPoolSize = 2, - producerPoolBound = 8 + rpcThreadPoolSize = 8 ) ) @@ -130,13 +128,10 @@ class RPCPerformanceTests : AbstractRPCTest() { val proxy = testProxy( RPCClientConfiguration.default.copy( reapInterval = 1.seconds, - cacheConcurrencyLevel = 16, - producerPoolBound = 8 + cacheConcurrencyLevel = 16 ), RPCServerConfiguration.default.copy( - rpcThreadPoolSize = 8, - consumerPoolSize = 1, - producerPoolBound = 8 + rpcThreadPoolSize = 8 ) ) startPublishingFixedRateInjector( @@ -165,9 +160,7 @@ class RPCPerformanceTests : AbstractRPCTest() { rpcDriver { val proxy = testProxy( RPCClientConfiguration.default, - RPCServerConfiguration.default.copy( - consumerPoolSize = 1 - ) + RPCServerConfiguration.default ) val numberOfMessages = 1000 val bigSize = 10_000_000 diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt index 96bd8c0560..338bac1bfd 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt @@ -10,6 +10,7 @@ import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.Id +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.Try import org.apache.activemq.artemis.api.core.ActiveMQBuffer import org.apache.activemq.artemis.api.core.SimpleString @@ -72,6 +73,9 @@ object RPCApi { const val RPC_CLIENT_BINDING_ADDITIONS = "rpc.clientqueueadditions" const val RPC_TARGET_LEGAL_IDENTITY = "rpc-target-legal-identity" + const val DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME = "deduplication-sequence-number" + + val RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION = "${ManagementHelper.HDR_NOTIFICATION_TYPE} = '${CoreNotificationType.BINDING_REMOVED.name}' AND " + "${ManagementHelper.HDR_ROUTING_NAME} LIKE '$RPC_CLIENT_QUEUE_NAME_PREFIX.%'" @@ -94,6 +98,8 @@ object RPCApi { OBSERVABLES_CLOSED } + abstract fun writeToClientMessage(message: ClientMessage) + /** * Request to a server to trigger the specified method with the provided arguments. * @@ -105,13 +111,13 @@ object RPCApi { data class RpcRequest( val clientAddress: SimpleString, val methodName: String, - val serialisedArguments: ByteArray, + val serialisedArguments: OpaqueBytes, val replyId: InvocationId, val sessionId: SessionId, val externalTrace: Trace? = null, val impersonatedActor: Actor? = null ) : ClientToServer() { - fun writeToClientMessage(message: ClientMessage) { + override fun writeToClientMessage(message: ClientMessage) { MessageUtil.setJMSReplyTo(message, clientAddress) message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal) @@ -122,12 +128,12 @@ object RPCApi { impersonatedActor?.mapToImpersonated(message) message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName) - message.bodyBuffer.writeBytes(serialisedArguments) + message.bodyBuffer.writeBytes(serialisedArguments.bytes) } } data class ObservablesClosed(val ids: List) : ClientToServer() { - fun writeToClientMessage(message: ClientMessage) { + override fun writeToClientMessage(message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVABLES_CLOSED.ordinal) val buffer = message.bodyBuffer buffer.writeInt(ids.size) @@ -144,7 +150,7 @@ object RPCApi { RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest( clientAddress = MessageUtil.getJMSReplyTo(message), methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME), - serialisedArguments = message.getBodyAsByteArray(), + serialisedArguments = OpaqueBytes(message.getBodyAsByteArray()), replyId = message.replyId(), sessionId = message.sessionId(), externalTrace = message.externalTrace(), @@ -175,13 +181,21 @@ object RPCApi { abstract fun writeToClientMessage(context: SerializationContext, message: ClientMessage) - /** Reply in response to an [ClientToServer.RpcRequest]. */ + abstract val deduplicationIdentity: String + + /** + * Reply in response to an [ClientToServer.RpcRequest]. + * @property deduplicationSequenceNumber a sequence number strictly incrementing with each message. Use this for + * duplicate detection on the client. + */ data class RpcReply( val id: InvocationId, - val result: Try + val result: Try, + override val deduplicationIdentity: String ) : ServerToClient() { override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal) + message.putStringProperty(DEDUPLICATION_IDENTITY_FIELD_NAME, deduplicationIdentity) id.mapTo(message, RPC_ID_FIELD_NAME, RPC_ID_TIMESTAMP_FIELD_NAME) message.bodyBuffer.writeBytes(result.safeSerialize(context) { Try.Failure(it) }.bytes) } @@ -189,10 +203,12 @@ object RPCApi { data class Observation( val id: InvocationId, - val content: Notification<*> + val content: Notification<*>, + override val deduplicationIdentity: String ) : ServerToClient() { override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal) + message.putStringProperty(DEDUPLICATION_IDENTITY_FIELD_NAME, deduplicationIdentity) id.mapTo(message, OBSERVABLE_ID_FIELD_NAME, OBSERVABLE_ID_TIMESTAMP_FIELD_NAME) message.bodyBuffer.writeBytes(content.safeSerialize(context) { Notification.createOnError(it) }.bytes) } @@ -207,17 +223,26 @@ object RPCApi { fun fromClientMessage(context: SerializationContext, message: ClientMessage): ServerToClient { val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] + val deduplicationIdentity = message.getStringProperty(DEDUPLICATION_IDENTITY_FIELD_NAME) return when (tag) { RPCApi.ServerToClient.Tag.RPC_REPLY -> { val id = message.invocationId(RPC_ID_FIELD_NAME, RPC_ID_TIMESTAMP_FIELD_NAME) ?: throw IllegalStateException("Cannot parse invocation id from client message.") val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id) - RpcReply(id, message.getBodyAsByteArray().deserialize(context = poolWithIdContext)) + RpcReply( + id = id, + deduplicationIdentity = deduplicationIdentity, + result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext) + ) } RPCApi.ServerToClient.Tag.OBSERVATION -> { val observableId = message.invocationId(OBSERVABLE_ID_FIELD_NAME, OBSERVABLE_ID_TIMESTAMP_FIELD_NAME) ?: throw IllegalStateException("Cannot parse invocation id from client message.") val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, observableId) val payload = message.getBodyAsByteArray().deserialize>(context = poolWithIdContext) - Observation(observableId, payload) + Observation( + id = observableId, + deduplicationIdentity = deduplicationIdentity, + content = payload + ) } } } @@ -225,18 +250,6 @@ object RPCApi { } } -data class ArtemisProducer( - val sessionFactory: ClientSessionFactory, - val session: ClientSession, - val producer: ClientProducer -) - -data class ArtemisConsumer( - val sessionFactory: ClientSessionFactory, - val session: ClientSession, - val consumer: ClientConsumer -) - private val TAG_FIELD_NAME = "tag" private val RPC_ID_FIELD_NAME = "rpc-id" private val RPC_ID_TIMESTAMP_FIELD_NAME = "rpc-id-timestamp" @@ -249,6 +262,7 @@ private val RPC_EXTERNAL_SESSION_ID_TIMESTAMP_FIELD_NAME = "rpc-external-session private val RPC_IMPERSONATED_ACTOR_ID = "rpc-impersonated-actor-id" private val RPC_IMPERSONATED_ACTOR_STORE_ID = "rpc-impersonated-actor-store-id" private val RPC_IMPERSONATED_ACTOR_OWNING_LEGAL_IDENTITY = "rpc-impersonated-actor-owningLegalIdentity" +private val DEDUPLICATION_IDENTITY_FIELD_NAME = "deduplication-identity" private val OBSERVABLE_ID_FIELD_NAME = "observable-id" private val OBSERVABLE_ID_TIMESTAMP_FIELD_NAME = "observable-id-timestamp" private val METHOD_NAME_FIELD_NAME = "method-name" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/DeduplicationChecker.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/DeduplicationChecker.kt new file mode 100644 index 0000000000..2fc69bbd1e --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/DeduplicationChecker.kt @@ -0,0 +1,29 @@ +package net.corda.nodeapi.internal + +import com.google.common.cache.CacheBuilder +import com.google.common.cache.CacheLoader +import java.time.Duration +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong + +/** + * A class allowing the deduplication of a strictly incrementing sequence number. + */ +class DeduplicationChecker(cacheExpiry: Duration) { + // dedupe identity -> watermark cache + private val watermarkCache = CacheBuilder.newBuilder() + .expireAfterAccess(cacheExpiry.toNanos(), TimeUnit.NANOSECONDS) + .build(WatermarkCacheLoader) + + private object WatermarkCacheLoader : CacheLoader() { + override fun load(key: Any) = AtomicLong(-1) + } + + /** + * @param identity the identity that generates the sequence numbers. + * @param sequenceNumber the sequence number to check. + */ + fun checkDuplicateMessageId(identity: Any, sequenceNumber: Long): Boolean { + return watermarkCache[identity].getAndUpdate { maxOf(sequenceNumber, it) } >= sequenceNumber + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index 0e06674aa0..eca9ce1601 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -18,9 +18,7 @@ import net.corda.core.context.InvocationContext import net.corda.core.context.Trace import net.corda.core.context.Trace.InvocationId import net.corda.core.identity.CordaX500Name -import net.corda.core.internal.LazyStickyPool import net.corda.core.internal.LifeCycle -import net.corda.core.internal.join import net.corda.core.messaging.RPCOps import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT @@ -29,14 +27,14 @@ import net.corda.core.utilities.* import net.corda.node.internal.security.AuthorizingSubject import net.corda.node.internal.security.RPCSecurityManager import net.corda.node.services.logging.pushToLoggingContext -import net.corda.nodeapi.* +import net.corda.nodeapi.RPCApi +import net.corda.nodeapi.externalTrace +import net.corda.nodeapi.impersonatedActor +import net.corda.nodeapi.internal.DeduplicationChecker import org.apache.activemq.artemis.api.core.Message import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE -import org.apache.activemq.artemis.api.core.client.ClientConsumer -import org.apache.activemq.artemis.api.core.client.ClientMessage -import org.apache.activemq.artemis.api.core.client.ClientSession -import org.apache.activemq.artemis.api.core.client.ServerLocator import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.apache.activemq.artemis.api.core.management.CoreNotificationType import org.apache.activemq.artemis.api.core.management.ManagementHelper @@ -49,24 +47,26 @@ import rx.Subscription import java.lang.reflect.InvocationTargetException import java.lang.reflect.Method import java.time.Duration +import java.util.* import java.util.concurrent.* +import kotlin.concurrent.thread data class RPCServerConfiguration( /** The number of threads to use for handling RPC requests */ val rpcThreadPoolSize: Int, - /** The number of consumers to handle incoming messages */ - val consumerPoolSize: Int, - /** The maximum number of producers to create to handle outgoing messages */ - val producerPoolBound: Int, /** The interval of subscription reaping */ - val reapInterval: Duration + val reapInterval: Duration, + /** The cache expiry of a deduplication watermark per client. */ + val deduplicationCacheExpiry: Duration, + /** The size of the send queue */ + val sendJobQueueSize: Int ) { companion object { val default = RPCServerConfiguration( rpcThreadPoolSize = 4, - consumerPoolSize = 2, - producerPoolBound = 4, - reapInterval = 1.seconds + reapInterval = 1.seconds, + deduplicationCacheExpiry = 1.days, + sendJobQueueSize = 256 ) } } @@ -115,22 +115,24 @@ class RPCServer( /** The scheduled reaper handle. */ private var reaperScheduledFuture: ScheduledFuture<*>? = null - private var observationSendExecutor: ExecutorService? = null + private var senderThread: Thread? = null private var rpcExecutor: ScheduledExecutorService? = null private var reaperExecutor: ScheduledExecutorService? = null - private val sessionAndConsumers = ArrayList(rpcConfiguration.consumerPoolSize) - private val sessionAndProducerPool = LazyStickyPool(rpcConfiguration.producerPoolBound) { - val sessionFactory = serverLocator.createSessionFactory() - val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - session.start() - ArtemisProducer(sessionFactory, session, session.createProducer()) - } + private var sessionFactory: ClientSessionFactory? = null + private var producerSession: ClientSession? = null + private var consumerSession: ClientSession? = null + private var rpcProducer: ClientProducer? = null + private var rpcConsumer: ClientConsumer? = null private var clientBindingRemovalConsumer: ClientConsumer? = null private var clientBindingAdditionConsumer: ClientConsumer? = null private var serverControl: ActiveMQServerControl? = null private val responseMessageBuffer = ConcurrentHashMap() + private val sendJobQueue = ArrayBlockingQueue(rpcConfiguration.sendJobQueueSize) + + private val deduplicationChecker = DeduplicationChecker(rpcConfiguration.deduplicationCacheExpiry) + private var deduplicationIdentity: String? = null init { val groupedMethods = ops.javaClass.declaredMethods.groupBy { it.name } @@ -154,16 +156,12 @@ class RPCServer( try { lifeCycle.requireState(State.UNSTARTED) log.info("Starting RPC server with configuration $rpcConfiguration") - observationSendExecutor = Executors.newFixedThreadPool( - 1, - ThreadFactoryBuilder().setNameFormat("rpc-observation-sender-%d").build() - ) + senderThread = startSenderThread() rpcExecutor = Executors.newScheduledThreadPool( rpcConfiguration.rpcThreadPoolSize, ThreadFactoryBuilder().setNameFormat("rpc-server-handler-pool-%d").build() ) - reaperExecutor = Executors.newScheduledThreadPool( - 1, + reaperExecutor = Executors.newSingleThreadScheduledExecutor( ThreadFactoryBuilder().setNameFormat("rpc-server-reaper-%d").build() ) reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate( @@ -172,55 +170,77 @@ class RPCServer( rpcConfiguration.reapInterval.toMillis(), TimeUnit.MILLISECONDS ) - val sessions = createConsumerSessions() - createNotificationConsumers() + + sessionFactory = serverLocator.createSessionFactory() + producerSession = sessionFactory!!.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + createRpcProducer(producerSession!!) + consumerSession = sessionFactory!!.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + createRpcConsumer(consumerSession!!) + createNotificationConsumers(consumerSession!!) serverControl = activeMqServerControl + deduplicationIdentity = UUID.randomUUID().toString() lifeCycle.transition(State.UNSTARTED, State.STARTED) // We delay the consumer session start because Artemis starts delivering messages immediately, so we need to be // fully initialised. - sessions.forEach { - it.start() - } + producerSession!!.start() + consumerSession!!.start() } catch (exception: Throwable) { close() throw exception } } - private fun createConsumerSessions(): ArrayList { - val sessions = ArrayList() - for (i in 1..rpcConfiguration.consumerPoolSize) { - val sessionFactory = serverLocator.createSessionFactory() - val session = sessionFactory.createSession(rpcServerUsername, rpcServerPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - val consumer = session.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME) - consumer.setMessageHandler(this@RPCServer::clientArtemisMessageHandler) - sessionAndConsumers.add(ArtemisConsumer(sessionFactory, session, consumer)) - sessions.add(session) - } - return sessions + private fun createRpcProducer(producerSession: ClientSession) { + rpcProducer = producerSession.createProducer() } - private fun createNotificationConsumers() { - clientBindingRemovalConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS) + private fun createRpcConsumer(consumerSession: ClientSession) { + rpcConsumer = consumerSession.createConsumer(RPCApi.RPC_SERVER_QUEUE_NAME) + rpcConsumer!!.setMessageHandler(this::clientArtemisMessageHandler) + } + + private fun createNotificationConsumers(consumerSession: ClientSession) { + clientBindingRemovalConsumer = consumerSession.createConsumer(RPCApi.RPC_CLIENT_BINDING_REMOVALS) clientBindingRemovalConsumer!!.setMessageHandler(this::bindingRemovalArtemisMessageHandler) - clientBindingAdditionConsumer = sessionAndConsumers[0].session.createConsumer(RPCApi.RPC_CLIENT_BINDING_ADDITIONS) + clientBindingAdditionConsumer = consumerSession.createConsumer(RPCApi.RPC_CLIENT_BINDING_ADDITIONS) clientBindingAdditionConsumer!!.setMessageHandler(this::bindingAdditionArtemisMessageHandler) } + private fun startSenderThread(): Thread { + return thread(name = "rpc-server-sender", isDaemon = true) { + var deduplicationSequenceNumber = 0L + while (true) { + val job = sendJobQueue.poll() + when (job) { + is RpcSendJob.Send -> handleSendJob(deduplicationSequenceNumber++, job) + RpcSendJob.Stop -> return@thread + } + } + } + } + + private fun handleSendJob(sequenceNumber: Long, job: RpcSendJob.Send) { + try { + job.artemisMessage.putLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME, sequenceNumber) + rpcProducer!!.send(job.clientAddress, job.artemisMessage) + log.debug { "<- RPC <- ${job.originalMessage}" } + } catch (throwable: Throwable) { + log.error("Failed to send message, kicking client. Message was ${job.originalMessage}", throwable) + serverControl!!.closeConsumerConnectionsForAddress(job.clientAddress.toString()) + invalidateClient(job.clientAddress) + } + } + fun close() { - observationSendExecutor?.join() + sendJobQueue.put(RpcSendJob.Stop) + senderThread?.join() reaperScheduledFuture?.cancel(false) rpcExecutor?.shutdownNow() reaperExecutor?.shutdownNow() securityManager.close() - sessionAndConsumers.forEach { - it.sessionFactory.close() - } + sessionFactory?.close() observableMap.invalidateAll() reapSubscriptions() - sessionAndProducerPool.close().forEach { - it.sessionFactory.close() - } lifeCycle.justTransition(State.FINISHED) } @@ -273,6 +293,14 @@ class RPCServer( log.debug { "-> RPC -> $clientToServer" } when (clientToServer) { is RPCApi.ClientToServer.RpcRequest -> { + val deduplicationSequenceNumber = artemisMessage.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME) + if (deduplicationChecker.checkDuplicateMessageId( + identity = clientToServer.clientAddress, + sequenceNumber = deduplicationSequenceNumber + )) { + log.info("Message duplication detected, discarding message") + return + } val arguments = Try.on { clientToServer.serialisedArguments.deserialize>(context = RPC_SERVER_CONTEXT) } @@ -316,15 +344,16 @@ class RPCServer( } private fun sendReply(replyId: InvocationId, clientAddress: SimpleString, result: Try) { - val reply = RPCApi.ServerToClient.RpcReply(replyId, result) + val reply = RPCApi.ServerToClient.RpcReply( + id = replyId, + result = result, + deduplicationIdentity = deduplicationIdentity!! + ) val observableContext = ObservableContext( - replyId, observableMap, clientAddressToObservables, - clientAddress, - serverControl!!, - sessionAndProducerPool, - observationSendExecutor!! + deduplicationIdentity!!, + clientAddress ) val buffered = bufferIfQueueNotBound(clientAddress, reply, observableContext) @@ -370,6 +399,34 @@ class RPCServer( val targetLegalIdentity = message.getStringProperty(RPCApi.RPC_TARGET_LEGAL_IDENTITY)?.let(CordaX500Name.Companion::parse) ?: nodeLegalName return Pair(Actor(Id(validatedUser), securityManager.id, targetLegalIdentity), securityManager.buildSubject(validatedUser)) } + + // We construct an observable context on each RPC request. If subsequently a nested Observable is + // encountered this same context is propagated by the instrumented KryoPool. This way all + // observations rooted in a single RPC will be muxed correctly. Note that the context construction + // itself is quite cheap. + inner class ObservableContext( + val observableMap: ObservableSubscriptionMap, + val clientAddressToObservables: SetMultimap, + val deduplicationIdentity: String, + val clientAddress: SimpleString + ) { + private val serializationContextWithObservableContext = RpcServerObservableSerializer.createContext(this) + + fun sendMessage(serverToClient: RPCApi.ServerToClient) { + val artemisMessage = producerSession!!.createMessage(false) + serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage) + sendJobQueue.put(RpcSendJob.Send(clientAddress, artemisMessage, serverToClient)) + } + } + + private sealed class RpcSendJob { + data class Send( + val clientAddress: SimpleString, + val artemisMessage: ClientMessage, + val originalMessage: RPCApi.ServerToClient + ) : RpcSendJob() + object Stop : RpcSendJob() + } } // TODO replace this by creating a new CordaRPCImpl for each request, passing the context, after we fix Shell and WebServer @@ -417,45 +474,11 @@ class ObservableSubscription( typealias ObservableSubscriptionMap = Cache -// We construct an observable context on each RPC request. If subsequently a nested Observable is -// encountered this same context is propagated by the instrumented KryoPool. This way all -// observations rooted in a single RPC will be muxed correctly. Note that the context construction -// itself is quite cheap. -class ObservableContext( - val invocationId: InvocationId, - val observableMap: ObservableSubscriptionMap, - val clientAddressToObservables: SetMultimap, - val clientAddress: SimpleString, - val serverControl: ActiveMQServerControl, - val sessionAndProducerPool: LazyStickyPool, - val observationSendExecutor: ExecutorService -) { - private companion object { - private val log = contextLogger() - } - - private val serializationContextWithObservableContext = RpcServerObservableSerializer.createContext(this) - - fun sendMessage(serverToClient: RPCApi.ServerToClient) { - try { - sessionAndProducerPool.run(invocationId) { - val artemisMessage = it.session.createMessage(false) - serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage) - it.producer.send(clientAddress, artemisMessage) - log.debug("<- RPC <- $serverToClient") - } - } catch (throwable: Throwable) { - log.error("Failed to send message, kicking client. Message was $serverToClient", throwable) - serverControl.closeConsumerConnectionsForAddress(clientAddress.toString()) - } - } -} - object RpcServerObservableSerializer : Serializer>() { private object RpcObservableContextKey private val log = LoggerFactory.getLogger(javaClass) - fun createContext(observableContext: ObservableContext): SerializationContext { + fun createContext(observableContext: RPCServer.ObservableContext): SerializationContext { return RPC_SERVER_CONTEXT.withProperty(RpcServerObservableSerializer.RpcObservableContextKey, observableContext) } @@ -465,7 +488,7 @@ object RpcServerObservableSerializer : Serializer>() { override fun write(kryo: Kryo, output: Output, observable: Observable<*>) { val observableId = InvocationId.newInstance() - val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext + val observableContext = kryo.context[RpcObservableContextKey] as RPCServer.ObservableContext output.writeInvocationId(observableId) val observableWithSubscription = ObservableSubscription( // We capture [observableContext] in the subscriber. Note that all synchronisation/kryo borrowing @@ -474,9 +497,12 @@ object RpcServerObservableSerializer : Serializer>() { object : Subscriber>() { override fun onNext(observation: Notification<*>) { if (!isUnsubscribed) { - observableContext.observationSendExecutor.submit { - observableContext.sendMessage(RPCApi.ServerToClient.Observation(observableId, observation)) - } + val message = RPCApi.ServerToClient.Observation( + id = observableId, + content = observation, + deduplicationIdentity = observableContext.deduplicationIdentity + ) + observableContext.sendMessage(message) } }