diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index 9857501ec3..a0a4a7ea1f 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -9,6 +9,7 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.config.SSLConfiguration +import net.corda.nodeapi.internal.serialization.AMQPClientSerializationScheme import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl @@ -71,6 +72,7 @@ class CordaRPCClient( try { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { registerScheme(KryoClientSerializationScheme()) + registerScheme(AMQPClientSerializationScheme()) } SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt new file mode 100644 index 0000000000..9772bef28a --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt @@ -0,0 +1,94 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0 +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory +import java.util.concurrent.ConcurrentHashMap + +private const val AMQP_ENABLED = false + +abstract class AbstractAMQPSerializationScheme : SerializationScheme { + private val serializerFactoriesForContexts = ConcurrentHashMap, SerializerFactory>() + + protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory + protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory + + private fun getSerializerFactory(context: SerializationContext): SerializerFactory { + return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { + when (context.useCase) { + SerializationContext.UseCase.Checkpoint -> + throw IllegalStateException("AMQP should not be used for checkpoint serialization.") + SerializationContext.UseCase.RPCClient -> + rpcClientSerializerFactory(context) + SerializationContext.UseCase.RPCServer -> + rpcServerSerializerFactory(context) + else -> SerializerFactory(context.whitelist) // TODO pass class loader also + } + } + } + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + val serializerFactory = getSerializerFactory(context) + return DeserializationInput(serializerFactory).deserialize(byteSequence, clazz) + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + val serializerFactory = getSerializerFactory(context) + return SerializationOutput(serializerFactory).serialize(obj) + } + + protected fun canDeserializeVersion(byteSequence: ByteSequence): Boolean = AMQP_ENABLED && byteSequence == AmqpHeaderV1_0 +} + +// TODO: This will eventually cover server RPC as well and move to node module, but for now this is not implemented +class AMQPServerSerializationScheme : AbstractAMQPSerializationScheme() { + override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { + throw UnsupportedOperationException() + } + + override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return (canDeserializeVersion(byteSequence) && + (target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage)) + } + +} + +// TODO: This will eventually cover client RPC as well and move to client module, but for now this is not implemented +class AMQPClientSerializationScheme : AbstractAMQPSerializationScheme() { + override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { + throw UnsupportedOperationException() + } + + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return (canDeserializeVersion(byteSequence) && + (target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage)) + } + +} + +val AMQP_P2P_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.P2P) +val AMQP_STORAGE_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0, + SerializationDefaults.javaClass.classLoader, + AllButBlacklisted, + emptyMap(), + true, + SerializationContext.UseCase.Storage) \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt index cd68bbb1bd..688dfacf7e 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt @@ -54,6 +54,8 @@ data class SerializationContextImpl(override val preferedSerializationVersion: B } } +private const val HEADER_SIZE: Int = 8 + open class SerializationFactoryImpl : SerializationFactory { private val creator: List = Exception().stackTrace.asList() @@ -63,8 +65,8 @@ open class SerializationFactoryImpl : SerializationFactory { private val schemes: ConcurrentHashMap, SerializationScheme> = ConcurrentHashMap() private fun schemeFor(byteSequence: ByteSequence, target: SerializationContext.UseCase): SerializationScheme { - // truncate sequence to 8 bytes - return schemes.computeIfAbsent(byteSequence.take(8).copy() to target) { + // truncate sequence to 8 bytes, and make sure it's a copy to avoid holding onto large ByteArrays + return schemes.computeIfAbsent(byteSequence.take(HEADER_SIZE).copy() to target) { for (scheme in registeredSchemes) { if (scheme.canDeserializeVersion(it.first, it.second)) { return@computeIfAbsent scheme @@ -162,11 +164,12 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { val pool = getPool(context) - Input(byteSequence.bytes, byteSequence.offset, byteSequence.size).use { input -> - val header = OpaqueBytes(input.readBytes(8)) - if (header != KryoHeaderV0_1) { - throw KryoException("Serialized bytes header does not match expected format.") - } + val headerSize = KryoHeaderV0_1.size + val header = byteSequence.take(headerSize) + if (header != KryoHeaderV0_1) { + throw KryoException("Serialized bytes header does not match expected format.") + } + Input(byteSequence.bytes, byteSequence.offset + headerSize, byteSequence.size - headerSize).use { input -> return pool.run { kryo -> withContext(kryo, context) { @Suppress("UNCHECKED_CAST") diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt index 907d39ec83..5654dfa20d 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt @@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.serialization.amqp import net.corda.core.internal.getStackTraceAsString import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.ByteSequence import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.DescribedType import org.apache.qpid.proton.amqp.UnsignedByte @@ -26,17 +27,6 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S internal companion object { val BYTES_NEEDED_TO_PEEK: Int = 23 - private fun subArraysEqual(a: ByteArray, aOffset: Int, length: Int, b: ByteArray, bOffset: Int): Boolean { - if (aOffset + length > a.size || bOffset + length > b.size) throw IndexOutOfBoundsException() - var bytesRemaining = length - var aPos = aOffset - var bPos = bOffset - while (bytesRemaining-- > 0) { - if (a[aPos++] != b[bPos++]) return false - } - return true - } - fun peekSize(bytes: ByteArray): Int { // There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor val eighth = bytes[8].toInt() @@ -69,15 +59,16 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S @Throws(NotSerializableException::class) - private fun getEnvelope(bytes: SerializedBytes): Envelope { + private fun getEnvelope(bytes: ByteSequence): Envelope { // Check that the lead bytes match expected header - if (!subArraysEqual(bytes.bytes, 0, 8, AmqpHeaderV1_0.bytes, 0)) { + val headerSize = AmqpHeaderV1_0.size + if (bytes.take(headerSize) != AmqpHeaderV1_0) { throw NotSerializableException("Serialization header does not match.") } val data = Data.Factory.create() - val size = data.decode(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8)) - if (size.toInt() != bytes.size - 8) { + val size = data.decode(ByteBuffer.wrap(bytes.bytes, bytes.offset + headerSize, bytes.size - headerSize)) + if (size.toInt() != bytes.size - headerSize) { throw NotSerializableException("Unexpected size of data") } @@ -103,7 +94,7 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S * be deserialized and a schema describing the types of the objects. */ @Throws(NotSerializableException::class) - fun deserialize(bytes: SerializedBytes, clazz: Class): T { + fun deserialize(bytes: ByteSequence, clazz: Class): T { return des { val envelope = getEnvelope(bytes) clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)) diff --git a/core/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java similarity index 100% rename from core/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 29acbc417b..eea378de27 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -332,6 +332,7 @@ open class Node(override val configuration: FullNodeConfiguration, private fun initialiseSerialization() { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) + registerScheme(AMQPServerSerializationScheme()) } SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT SerializationDefaults.RPC_SERVER_CONTEXT = KRYO_RPC_SERVER_CONTEXT diff --git a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt index b6612af10e..82d7925fec 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt @@ -63,6 +63,8 @@ fun initialiseTestSerialization() { (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply { registerScheme(KryoClientSerializationScheme()) registerScheme(KryoServerSerializationScheme()) + registerScheme(AMQPClientSerializationScheme()) + registerScheme(AMQPServerSerializationScheme()) } (SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate = KRYO_P2P_CONTEXT (SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_SERVER_CONTEXT