Merge pull request #2477 from corda/aslemmer-corda/issues/2300

Add RPC deduplication to client and server
This commit is contained in:
Andras Slemmer
2018-02-19 16:25:36 +00:00
committed by GitHub
8 changed files with 376 additions and 221 deletions

View File

@ -10,6 +10,7 @@ import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.Id
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.Try
import org.apache.activemq.artemis.api.core.ActiveMQBuffer
import org.apache.activemq.artemis.api.core.SimpleString
@ -72,6 +73,8 @@ object RPCApi {
const val RPC_CLIENT_BINDING_ADDITIONS = "rpc.clientqueueadditions"
const val RPC_TARGET_LEGAL_IDENTITY = "rpc-target-legal-identity"
const val DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME = "deduplication-sequence-number"
val RPC_CLIENT_BINDING_REMOVAL_FILTER_EXPRESSION =
"${ManagementHelper.HDR_NOTIFICATION_TYPE} = '${CoreNotificationType.BINDING_REMOVED.name}' AND " +
"${ManagementHelper.HDR_ROUTING_NAME} LIKE '$RPC_CLIENT_QUEUE_NAME_PREFIX.%'"
@ -94,6 +97,8 @@ object RPCApi {
OBSERVABLES_CLOSED
}
abstract fun writeToClientMessage(message: ClientMessage)
/**
* Request to a server to trigger the specified method with the provided arguments.
*
@ -105,13 +110,13 @@ object RPCApi {
data class RpcRequest(
val clientAddress: SimpleString,
val methodName: String,
val serialisedArguments: ByteArray,
val serialisedArguments: OpaqueBytes,
val replyId: InvocationId,
val sessionId: SessionId,
val externalTrace: Trace? = null,
val impersonatedActor: Actor? = null
) : ClientToServer() {
fun writeToClientMessage(message: ClientMessage) {
override fun writeToClientMessage(message: ClientMessage) {
MessageUtil.setJMSReplyTo(message, clientAddress)
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal)
@ -122,12 +127,12 @@ object RPCApi {
impersonatedActor?.mapToImpersonated(message)
message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName)
message.bodyBuffer.writeBytes(serialisedArguments)
message.bodyBuffer.writeBytes(serialisedArguments.bytes)
}
}
data class ObservablesClosed(val ids: List<InvocationId>) : ClientToServer() {
fun writeToClientMessage(message: ClientMessage) {
override fun writeToClientMessage(message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVABLES_CLOSED.ordinal)
val buffer = message.bodyBuffer
buffer.writeInt(ids.size)
@ -144,7 +149,7 @@ object RPCApi {
RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest(
clientAddress = MessageUtil.getJMSReplyTo(message),
methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME),
serialisedArguments = message.getBodyAsByteArray(),
serialisedArguments = OpaqueBytes(message.getBodyAsByteArray()),
replyId = message.replyId(),
sessionId = message.sessionId(),
externalTrace = message.externalTrace(),
@ -175,13 +180,20 @@ object RPCApi {
abstract fun writeToClientMessage(context: SerializationContext, message: ClientMessage)
/** Reply in response to an [ClientToServer.RpcRequest]. */
/** The identity used to identify the deduplication ID sequence. This should be unique per server JVM run */
abstract val deduplicationIdentity: String
/**
* Reply in response to an [ClientToServer.RpcRequest].
*/
data class RpcReply(
val id: InvocationId,
val result: Try<Any?>
val result: Try<Any?>,
override val deduplicationIdentity: String
) : ServerToClient() {
override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal)
message.putStringProperty(DEDUPLICATION_IDENTITY_FIELD_NAME, deduplicationIdentity)
id.mapTo(message, RPC_ID_FIELD_NAME, RPC_ID_TIMESTAMP_FIELD_NAME)
message.bodyBuffer.writeBytes(result.safeSerialize(context) { Try.Failure<Any>(it) }.bytes)
}
@ -189,10 +201,12 @@ object RPCApi {
data class Observation(
val id: InvocationId,
val content: Notification<*>
val content: Notification<*>,
override val deduplicationIdentity: String
) : ServerToClient() {
override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal)
message.putStringProperty(DEDUPLICATION_IDENTITY_FIELD_NAME, deduplicationIdentity)
id.mapTo(message, OBSERVABLE_ID_FIELD_NAME, OBSERVABLE_ID_TIMESTAMP_FIELD_NAME)
message.bodyBuffer.writeBytes(content.safeSerialize(context) { Notification.createOnError<Void?>(it) }.bytes)
}
@ -207,17 +221,26 @@ object RPCApi {
fun fromClientMessage(context: SerializationContext, message: ClientMessage): ServerToClient {
val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)]
val deduplicationIdentity = message.getStringProperty(DEDUPLICATION_IDENTITY_FIELD_NAME)
return when (tag) {
RPCApi.ServerToClient.Tag.RPC_REPLY -> {
val id = message.invocationId(RPC_ID_FIELD_NAME, RPC_ID_TIMESTAMP_FIELD_NAME) ?: throw IllegalStateException("Cannot parse invocation id from client message.")
val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id)
RpcReply(id, message.getBodyAsByteArray().deserialize(context = poolWithIdContext))
RpcReply(
id = id,
deduplicationIdentity = deduplicationIdentity,
result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext)
)
}
RPCApi.ServerToClient.Tag.OBSERVATION -> {
val observableId = message.invocationId(OBSERVABLE_ID_FIELD_NAME, OBSERVABLE_ID_TIMESTAMP_FIELD_NAME) ?: throw IllegalStateException("Cannot parse invocation id from client message.")
val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, observableId)
val payload = message.getBodyAsByteArray().deserialize<Notification<*>>(context = poolWithIdContext)
Observation(observableId, payload)
Observation(
id = observableId,
deduplicationIdentity = deduplicationIdentity,
content = payload
)
}
}
}
@ -225,18 +248,6 @@ object RPCApi {
}
}
data class ArtemisProducer(
val sessionFactory: ClientSessionFactory,
val session: ClientSession,
val producer: ClientProducer
)
data class ArtemisConsumer(
val sessionFactory: ClientSessionFactory,
val session: ClientSession,
val consumer: ClientConsumer
)
private val TAG_FIELD_NAME = "tag"
private val RPC_ID_FIELD_NAME = "rpc-id"
private val RPC_ID_TIMESTAMP_FIELD_NAME = "rpc-id-timestamp"
@ -249,6 +260,7 @@ private val RPC_EXTERNAL_SESSION_ID_TIMESTAMP_FIELD_NAME = "rpc-external-session
private val RPC_IMPERSONATED_ACTOR_ID = "rpc-impersonated-actor-id"
private val RPC_IMPERSONATED_ACTOR_STORE_ID = "rpc-impersonated-actor-store-id"
private val RPC_IMPERSONATED_ACTOR_OWNING_LEGAL_IDENTITY = "rpc-impersonated-actor-owningLegalIdentity"
private val DEDUPLICATION_IDENTITY_FIELD_NAME = "deduplication-identity"
private val OBSERVABLE_ID_FIELD_NAME = "observable-id"
private val OBSERVABLE_ID_TIMESTAMP_FIELD_NAME = "observable-id-timestamp"
private val METHOD_NAME_FIELD_NAME = "method-name"

View File

@ -0,0 +1,30 @@
package net.corda.nodeapi.internal
import com.google.common.cache.CacheBuilder
import com.google.common.cache.CacheLoader
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
/**
* A class allowing the deduplication of a strictly incrementing sequence number.
*/
class DeduplicationChecker(cacheExpiry: Duration) {
// dedupe identity -> watermark cache
private val watermarkCache = CacheBuilder.newBuilder()
.expireAfterAccess(cacheExpiry.toNanos(), TimeUnit.NANOSECONDS)
.build(WatermarkCacheLoader)
private object WatermarkCacheLoader : CacheLoader<Any, AtomicLong>() {
override fun load(key: Any) = AtomicLong(-1)
}
/**
* @param identity the identity that generates the sequence numbers.
* @param sequenceNumber the sequence number to check.
* @return true if the message is unique, false if it's a duplicate.
*/
fun checkDuplicateMessageId(identity: Any, sequenceNumber: Long): Boolean {
return watermarkCache[identity].getAndUpdate { maxOf(sequenceNumber, it) } >= sequenceNumber
}
}

View File

@ -43,7 +43,10 @@ enum class TransactionIsolationLevel {
}
private val _contextDatabase = ThreadLocal<CordaPersistence>()
val contextDatabase get() = _contextDatabase.get() ?: error("Was expecting to find CordaPersistence set on current thread: ${Strand.currentStrand()}")
var contextDatabase: CordaPersistence
get() = _contextDatabase.get() ?: error("Was expecting to find CordaPersistence set on current thread: ${Strand.currentStrand()}")
set(database) = _contextDatabase.set(database)
val contextDatabaseOrNull: CordaPersistence? get() = _contextDatabase.get()
class CordaPersistence(
val dataSource: DataSource,