diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index 35c1baccaf..8b45999554 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -31,26 +31,15 @@ import org.apache.activemq.artemis.api.core.ActiveMQException import org.apache.activemq.artemis.api.core.ActiveMQNotConnectedException import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE -import org.apache.activemq.artemis.api.core.client.ClientConsumer -import org.apache.activemq.artemis.api.core.client.ClientMessage -import org.apache.activemq.artemis.api.core.client.ClientProducer -import org.apache.activemq.artemis.api.core.client.ClientSession -import org.apache.activemq.artemis.api.core.client.ClientSessionFactory -import org.apache.activemq.artemis.api.core.client.FailoverEventType -import org.apache.activemq.artemis.api.core.client.ServerLocator import rx.Notification import rx.Observable import rx.subjects.UnicastSubject import java.lang.reflect.InvocationHandler import java.lang.reflect.Method import java.util.* -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.ExecutorService -import java.util.concurrent.Executors -import java.util.concurrent.ScheduledExecutorService -import java.util.concurrent.ScheduledFuture -import java.util.concurrent.TimeUnit +import java.util.concurrent.* import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicLong import kotlin.reflect.jvm.javaMethod @@ -288,56 +277,71 @@ class RPCClientProxyHandler( // The handler for Artemis messages. private fun artemisMessageHandler(message: ClientMessage) { - val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message) - val deduplicationSequenceNumber = message.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME) - if (deduplicationChecker.checkDuplicateMessageId(serverToClient.deduplicationIdentity, deduplicationSequenceNumber)) { - log.info("Message duplication detected, discarding message") - return + fun completeExceptionally(id: InvocationId, e: Throwable, future: SettableFuture?) { + val rpcCallSite: Throwable? = callSiteMap?.get(id) + if (rpcCallSite != null) addRpcCallSiteToThrowable(e, rpcCallSite) + future?.setException(e.cause ?: e) } - log.debug { "Got message from RPC server $serverToClient" } - when (serverToClient) { - is RPCApi.ServerToClient.RpcReply -> { - val replyFuture = rpcReplyMap.remove(serverToClient.id) - if (replyFuture == null) { - log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.") - } else { - val result = serverToClient.result - when (result) { - is Try.Success -> replyFuture.set(result.value) - is Try.Failure -> { - val rpcCallSite = callSiteMap?.get(serverToClient.id) - if (rpcCallSite != null) addRpcCallSiteToThrowable(result.exception, rpcCallSite) - replyFuture.setException(result.exception) - } - } - } - } - is RPCApi.ServerToClient.Observation -> { - val observable = observableContext.observableMap.getIfPresent(serverToClient.id) - if (observable == null) { - log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " + - "This may be due to an observation arriving before the server was " + - "notified of observable shutdown") - } else { - // We schedule the onNext() on an executor sticky-pooled based on the Observable ID. - observationExecutorPool.run(serverToClient.id) { executor -> - executor.submit { - val content = serverToClient.content - if (content.isOnCompleted || content.isOnError) { - observableContext.observableMap.invalidate(serverToClient.id) - } - // Add call site information on error - if (content.isOnError) { - val rpcCallSite = callSiteMap?.get(serverToClient.id) - if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite) - } - observable.onNext(content) + + try { + // Deserialize the reply from the server, both the wrapping metadata and the actual body of the return value. + val serverToClient: RPCApi.ServerToClient = try { + RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message) + } catch (e: RPCApi.ServerToClient.FailedToDeserializeReply) { + // Might happen if something goes wrong during mapping the response to classes, evolution, class synthesis etc. + log.error("Failed to deserialize RPC body", e) + completeExceptionally(e.id, e, rpcReplyMap.remove(e.id)) + return + } + val deduplicationSequenceNumber = message.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME) + if (deduplicationChecker.checkDuplicateMessageId(serverToClient.deduplicationIdentity, deduplicationSequenceNumber)) { + log.info("Message duplication detected, discarding message") + return + } + log.debug { "Got message from RPC server $serverToClient" } + when (serverToClient) { + is RPCApi.ServerToClient.RpcReply -> { + val replyFuture = rpcReplyMap.remove(serverToClient.id) + if (replyFuture == null) { + log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.") + } else { + val result: Try = serverToClient.result + when (result) { + is Try.Success -> replyFuture.set(result.value) + is Try.Failure -> { + completeExceptionally(serverToClient.id, result.exception, replyFuture) + } + } + } + } + is RPCApi.ServerToClient.Observation -> { + val observable: UnicastSubject>? = observableContext.observableMap.getIfPresent(serverToClient.id) + if (observable == null) { + log.debug("Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " + + "This may be due to an observation arriving before the server was " + + "notified of observable shutdown") + } else { + // We schedule the onNext() on an executor sticky-pooled based on the Observable ID. + observationExecutorPool.run(serverToClient.id) { executor -> + executor.submit { + val content = serverToClient.content + if (content.isOnCompleted || content.isOnError) { + observableContext.observableMap.invalidate(serverToClient.id) + } + // Add call site information on error + if (content.isOnError) { + val rpcCallSite = callSiteMap?.get(serverToClient.id) + if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite) + } + observable.onNext(content) + } } } } } + } finally { + message.acknowledge() } - message.acknowledge() } /** diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt index 1a482e1ff0..1f118aa631 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt @@ -14,7 +14,7 @@ import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.Try import org.apache.activemq.artemis.api.core.ActiveMQBuffer import org.apache.activemq.artemis.api.core.SimpleString -import org.apache.activemq.artemis.api.core.client.* +import org.apache.activemq.artemis.api.core.client.ClientMessage import org.apache.activemq.artemis.api.core.management.CoreNotificationType import org.apache.activemq.artemis.api.core.management.ManagementHelper import org.apache.activemq.artemis.reader.MessageUtil @@ -212,6 +212,11 @@ object RPCApi { } } + /** + * Thrown if the RPC reply body couldn't be deserialized. + */ + class FailedToDeserializeReply(val id: InvocationId, cause: Throwable) : RuntimeException("Failed to deserialize RPC reply: ${cause.message}", cause) + companion object { private fun Any.safeSerialize(context: SerializationContext, wrap: (Throwable) -> Any) = try { serialize(context = context) @@ -226,10 +231,18 @@ object RPCApi { RPCApi.ServerToClient.Tag.RPC_REPLY -> { val id = message.invocationId(RPC_ID_FIELD_NAME, RPC_ID_TIMESTAMP_FIELD_NAME) ?: throw IllegalStateException("Cannot parse invocation id from client message.") val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id) + // The result here is a Try<> that represents the attempt to try the operation on the server side. + // If anything goes wrong with deserialisation of the response, we propagate it differently because + // we also need to pass through the invocation and dedupe IDs. + val result: Try = try { + message.getBodyAsByteArray().deserialize(context = poolWithIdContext) + } catch (e: Exception) { + throw FailedToDeserializeReply(id, e) + } RpcReply( id = id, deduplicationIdentity = deduplicationIdentity, - result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext) + result = result ) } RPCApi.ServerToClient.Tag.OBSERVATION -> {