Delay RPC arguments deserialisation to allow routing of errors

This commit is contained in:
Andras Slemmer 2017-09-26 10:53:07 +01:00
parent 20a9892123
commit 9d115a2111
5 changed files with 36 additions and 60 deletions

View File

@ -7,6 +7,7 @@ import net.corda.core.internal.concurrent.fork
import net.corda.core.internal.concurrent.transpose import net.corda.core.internal.concurrent.transpose
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi import net.corda.nodeapi.RPCApi
@ -315,9 +316,9 @@ class RPCStabilityTests {
clientAddress = SimpleString(myQueue), clientAddress = SimpleString(myQueue),
id = RPCApi.RpcRequestId(random63BitValue()), id = RPCApi.RpcRequestId(random63BitValue()),
methodName = SlowConsumerRPCOps::streamAtInterval.name, methodName = SlowConsumerRPCOps::streamAtInterval.name,
arguments = listOf(10.millis, 123456) serialisedArguments = listOf(10.millis, 123456).serialize(context = SerializationDefaults.RPC_SERVER_CONTEXT).bytes
) )
request.writeToClientMessage(SerializationDefaults.RPC_SERVER_CONTEXT, message) request.writeToClientMessage(message)
producer.send(message) producer.send(message)
session.commit() session.commit()

View File

@ -19,6 +19,7 @@ import net.corda.core.internal.LifeCycle
import net.corda.core.internal.ThreadBox import net.corda.core.internal.ThreadBox
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
@ -208,11 +209,12 @@ class RPCClientProxyHandler(
val rpcId = RPCApi.RpcRequestId(random63BitValue()) val rpcId = RPCApi.RpcRequestId(random63BitValue())
callSiteMap?.set(rpcId.toLong, Throwable("<Call site of root RPC '${method.name}'>")) callSiteMap?.set(rpcId.toLong, Throwable("<Call site of root RPC '${method.name}'>"))
try { try {
val request = RPCApi.ClientToServer.RpcRequest(clientAddress, rpcId, method.name, arguments?.toList() ?: emptyList()) val serialisedArguments = (arguments?.toList() ?: emptyList()).serialize(context = serializationContextWithObservableContext)
val request = RPCApi.ClientToServer.RpcRequest(clientAddress, rpcId, method.name, serialisedArguments.bytes)
val replyFuture = SettableFuture.create<Any>() val replyFuture = SettableFuture.create<Any>()
sessionAndProducerPool.run { sessionAndProducerPool.run {
val message = it.session.createMessage(false) val message = it.session.createMessage(false)
request.writeToClientMessage(serializationContextWithObservableContext, message) request.writeToClientMessage(message)
log.debug { log.debug {
val argumentsString = arguments?.joinToString() ?: "" val argumentsString = arguments?.joinToString() ?: ""

View File

@ -548,7 +548,6 @@ public class FlowCookbookJava {
// DOCSTART 37 // DOCSTART 37
twiceSignedTx.checkSignaturesAreValid(); twiceSignedTx.checkSignaturesAreValid();
// DOCEND 37 // DOCEND 37
} catch (GeneralSecurityException e) { } catch (GeneralSecurityException e) {
// Handle this as required. // Handle this as required.
} }

View File

@ -97,20 +97,20 @@ object RPCApi {
* @param clientAddress return address to contact the client at. * @param clientAddress return address to contact the client at.
* @param id a unique ID for the request, which the server will use to identify its response with. * @param id a unique ID for the request, which the server will use to identify its response with.
* @param methodName name of the method (procedure) to be called. * @param methodName name of the method (procedure) to be called.
* @param arguments arguments to pass to the method, if any. * @param serialisedArguments Serialised arguments to pass to the method, if any.
*/ */
data class RpcRequest( data class RpcRequest(
val clientAddress: SimpleString, val clientAddress: SimpleString,
val id: RpcRequestId, val id: RpcRequestId,
val methodName: String, val methodName: String,
val arguments: List<Any?> val serialisedArguments: ByteArray
) : ClientToServer() { ) : ClientToServer() {
fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { fun writeToClientMessage(message: ClientMessage) {
MessageUtil.setJMSReplyTo(message, clientAddress) MessageUtil.setJMSReplyTo(message, clientAddress)
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal) message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal)
message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong)
message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName) message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName)
message.bodyBuffer.writeBytes(arguments.serialize(context = context).bytes) message.bodyBuffer.writeBytes(serialisedArguments)
} }
} }
@ -128,26 +128,15 @@ object RPCApi {
} }
companion object { companion object {
fun fromClientMessage(message: ClientMessage): ClientToServer {
fun fromClientMessage(context: SerializationContext, message: ClientMessage): ClientToServer {
val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)]
return when (tag) { return when (tag) {
RPCApi.ClientToServer.Tag.RPC_REQUEST -> { RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest(
val partialReq = RpcRequest(
clientAddress = MessageUtil.getJMSReplyTo(message), clientAddress = MessageUtil.getJMSReplyTo(message),
id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)), id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)),
methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME), methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME),
arguments = emptyList() serialisedArguments = message.getBodyAsByteArray()
) )
// Deserialisation of the arguments can fail, but we'd like to return a response mapped to
// this specific RPC, so we throw the partial request with ID and method name.
try {
val arguments = message.getBodyAsByteArray().deserialize<List<Any?>>(context = context)
return partialReq.copy(arguments = arguments)
} catch (t: Throwable) {
throw ArgumentDeserialisationException(t, partialReq)
}
}
RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> { RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> {
val ids = ArrayList<ObservableId>() val ids = ArrayList<ObservableId>()
val buffer = message.bodyBuffer val buffer = message.bodyBuffer
@ -160,7 +149,6 @@ object RPCApi {
} }
} }
} }
} }
/** /**
@ -227,13 +215,6 @@ object RPCApi {
} }
} }
} }
/**
* Thrown when the arguments passed to an RPC couldn't be deserialised by the server. This will
* typically indicate a missing application on the server side, or different versions between
* client and server.
*/
class ArgumentDeserialisationException(cause: Throwable, val reqWithNoArguments: ClientToServer.RpcRequest) : RPCException("Failed to deserialise RPC arguments: version or app skew between client and server?", cause)
} }
data class ArtemisProducer( data class ArtemisProducer(

View File

@ -19,6 +19,7 @@ import net.corda.core.internal.LifeCycle
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT
import net.corda.core.serialization.deserialize
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
@ -260,36 +261,28 @@ class RPCServer(
private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) { private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) {
lifeCycle.requireState(State.STARTED) lifeCycle.requireState(State.STARTED)
val clientToServer = RPCApi.ClientToServer.fromClientMessage(artemisMessage)
// Attempt de-serialisation of the RPC request message, in such a way that we can route the error back to
// the RPC that was being tried if it fails in a method/rpc specific way.
val clientToServerTry = Try.on { RPCApi.ClientToServer.fromClientMessage(RPC_SERVER_CONTEXT, artemisMessage) }
val clientToServer = try {
clientToServerTry.getOrThrow()
} catch (e: RPCApi.ArgumentDeserialisationException) {
// The exception itself has a more informative error message, and this could be caused by buggy apps, so
// let's just log it as a warning instead of an error. Relay the failure to the client so they can see it.
log.warn("Inbound RPC failed", e)
sendReply(e.reqWithNoArguments.id, e.reqWithNoArguments.clientAddress, Try.Failure(e.cause!!))
return
} catch (e: Exception) {
// This path indicates something more fundamental went wrong, like a missing message header.
log.error("Failed to parse an inbound RPC: version skew between client and server?", e)
return
} finally {
artemisMessage.acknowledge()
}
// Now try dispatching the request itself.
log.debug { "-> RPC -> $clientToServer" } log.debug { "-> RPC -> $clientToServer" }
when (clientToServer) { when (clientToServer) {
is RPCApi.ClientToServer.RpcRequest -> { is RPCApi.ClientToServer.RpcRequest -> {
val arguments = Try.on {
clientToServer.serialisedArguments.deserialize<List<Any?>>(context = RPC_SERVER_CONTEXT)
}
when (arguments) {
is Try.Success -> {
val rpcContext = RpcContext(currentUser = getUser(artemisMessage)) val rpcContext = RpcContext(currentUser = getUser(artemisMessage))
rpcExecutor!!.submit { rpcExecutor!!.submit {
val result = invokeRpc(rpcContext, clientToServer.methodName, clientToServer.arguments) val result = invokeRpc(rpcContext, clientToServer.methodName, arguments.value)
sendReply(clientToServer.id, clientToServer.clientAddress, result) sendReply(clientToServer.id, clientToServer.clientAddress, result)
} }
} }
is Try.Failure -> {
// We failed to deserialise the arguments, route back the error
log.warn("Inbound RPC failed", arguments.exception)
sendReply(clientToServer.id, clientToServer.clientAddress, arguments)
}
}
}
is RPCApi.ClientToServer.ObservablesClosed -> { is RPCApi.ClientToServer.ObservablesClosed -> {
observableMap.invalidateAll(clientToServer.ids) observableMap.invalidateAll(clientToServer.ids)
} }