CORDA-973 Compression support for serialization (#2473)

* Serialization magic is now 7 bytes
* Introduce encoding property and whitelist
This commit is contained in:
Andrzej Cichocki 2018-02-23 13:07:51 +00:00 committed by GitHub
parent 2af0feee04
commit c8672d373f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 436 additions and 90 deletions

View File

@ -2817,7 +2817,7 @@ public final class net.corda.core.serialization.ObjectWithCompatibleContext exte
public final class net.corda.core.serialization.SerializationAPIKt extends java.lang.Object public final class net.corda.core.serialization.SerializationAPIKt extends java.lang.Object
@org.jetbrains.annotations.NotNull public static final net.corda.core.serialization.SerializedBytes serialize(Object, net.corda.core.serialization.SerializationFactory, net.corda.core.serialization.SerializationContext) @org.jetbrains.annotations.NotNull public static final net.corda.core.serialization.SerializedBytes serialize(Object, net.corda.core.serialization.SerializationFactory, net.corda.core.serialization.SerializationContext)
## ##
public interface net.corda.core.serialization.SerializationContext @net.corda.core.DoNotImplement public interface net.corda.core.serialization.SerializationContext
@org.jetbrains.annotations.NotNull public abstract ClassLoader getDeserializationClassLoader() @org.jetbrains.annotations.NotNull public abstract ClassLoader getDeserializationClassLoader()
public abstract boolean getObjectReferencesEnabled() public abstract boolean getObjectReferencesEnabled()
@org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.ByteSequence getPreferredSerializationVersion() @org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.ByteSequence getPreferredSerializationVersion()

View File

@ -1,5 +1,6 @@
package net.corda.core.serialization package net.corda.core.serialization
import net.corda.core.DoNotImplement
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256 import net.corda.core.crypto.sha256
import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.serialization.internal.effectiveSerializationEnv
@ -99,14 +100,22 @@ abstract class SerializationFactory {
} }
} }
typealias SerializationMagic = ByteSequence typealias SerializationMagic = ByteSequence
@DoNotImplement
interface SerializationEncoding
/** /**
* Parameters to serialization and deserialization. * Parameters to serialization and deserialization.
*/ */
@DoNotImplement
interface SerializationContext { interface SerializationContext {
/** /**
* When serializing, use the format this header sequence represents. * When serializing, use the format this header sequence represents.
*/ */
val preferredSerializationVersion: SerializationMagic val preferredSerializationVersion: SerializationMagic
/**
* If non-null, apply this encoding (typically compression) when serializing.
*/
val encoding: SerializationEncoding?
/** /**
* The class loader to use for deserialization. * The class loader to use for deserialization.
*/ */
@ -115,6 +124,10 @@ interface SerializationContext {
* A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized. * A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized.
*/ */
val whitelist: ClassWhitelist val whitelist: ClassWhitelist
/**
* A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing.
*/
val encodingWhitelist: EncodingWhitelist
/** /**
* A map of any addition properties specific to the particular use case. * A map of any addition properties specific to the particular use case.
*/ */
@ -161,6 +174,11 @@ interface SerializationContext {
*/ */
fun withPreferredSerializationVersion(magic: SerializationMagic): SerializationContext fun withPreferredSerializationVersion(magic: SerializationMagic): SerializationContext
/**
* A shallow copy of this context but with the given (possibly null) encoding.
*/
fun withEncoding(encoding: SerializationEncoding?): SerializationContext
/** /**
* The use case that we are serializing for, since it influences the implementations chosen. * The use case that we are serializing for, since it influences the implementations chosen.
*/ */
@ -232,3 +250,8 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {
interface ClassWhitelist { interface ClassWhitelist {
fun hasListed(type: Class<*>): Boolean fun hasListed(type: Class<*>): Boolean
} }
@DoNotImplement
interface EncodingWhitelist {
fun acceptEncoding(encoding: SerializationEncoding): Boolean
}

View File

@ -34,6 +34,9 @@ dependencies {
// For AMQP serialisation. // For AMQP serialisation.
compile "org.apache.qpid:proton-j:0.21.0" compile "org.apache.qpid:proton-j:0.21.0"
// Pure-Java Snappy compression
compile 'org.iq80.snappy:snappy:0.4'
// Unit testing helpers. // Unit testing helpers.
testCompile "junit:junit:$junit_version" testCompile "junit:junit:$junit_version"
testCompile "org.assertj:assertj-core:$assertj_version" testCompile "org.assertj:assertj-core:$assertj_version"

View File

@ -18,10 +18,12 @@ val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(kryoMagic,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.RPCClient) SerializationContext.UseCase.RPCClient,
null)
val AMQP_RPC_CLIENT_CONTEXT = SerializationContextImpl(amqpMagic, val AMQP_RPC_CLIENT_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.RPCClient) SerializationContext.UseCase.RPCClient,
null)

View File

@ -0,0 +1,31 @@
package net.corda.nodeapi.internal.serialization
import java.io.EOFException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
class OrdinalBits(private val ordinal: Int) {
interface OrdinalWriter {
val bits: OrdinalBits
val encodedSize get() = 1
fun writeTo(stream: OutputStream) = stream.write(bits.ordinal)
fun putTo(buffer: ByteBuffer) = buffer.put(bits.ordinal.toByte())!!
}
init {
require(ordinal >= 0) { "The ordinal must be non-negative." }
require(ordinal < 128) { "Consider implementing a varint encoding." }
}
}
class OrdinalReader<out E : Any>(private val values: Array<E>) {
private val enumName = values[0].javaClass.simpleName
private val range = 0 until values.size
fun readFrom(stream: InputStream): E {
val ordinal = stream.read()
if (ordinal == -1) throw EOFException("Expected a $enumName ordinal.")
if (ordinal !in range) throw NoSuchElementException("No $enumName with ordinal: $ordinal")
return values[ordinal]
}
}

View File

@ -1,8 +1,17 @@
package net.corda.nodeapi.internal.serialization package net.corda.nodeapi.internal.serialization
import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.serialization.OrdinalBits.OrdinalWriter
import org.iq80.snappy.SnappyFramedInputStream
import org.iq80.snappy.SnappyFramedOutputStream
import java.io.OutputStream
import java.io.InputStream
import java.nio.ByteBuffer import java.nio.ByteBuffer
import java.util.zip.DeflaterOutputStream
import java.util.zip.InflaterInputStream
class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) { class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) {
private val bufferView = slice() private val bufferView = slice()
@ -10,3 +19,40 @@ class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) {
return if (data.slice(end = size) == bufferView) data.slice(size) else null return if (data.slice(end = size) == bufferView) data.slice(size) else null
} }
} }
enum class SectionId : OrdinalWriter {
/** Serialization data follows, and then discard the rest of the stream (if any) as legacy data may have trailing garbage. */
DATA_AND_STOP,
/** Identical behaviour to [DATA_AND_STOP], historically used for Kryo. Do not use in new code. */
ALT_DATA_AND_STOP,
/** The ordinal of a [CordaSerializationEncoding] follows, which should be used to decode the remainder of the stream. */
ENCODING;
companion object {
val reader = OrdinalReader(values())
}
override val bits = OrdinalBits(ordinal)
}
enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
DEFLATE {
override fun wrap(stream: OutputStream) = DeflaterOutputStream(stream)
override fun wrap(stream: InputStream) = InflaterInputStream(stream)
},
SNAPPY {
override fun wrap(stream: OutputStream) = SnappyFramedOutputStream(stream)
override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false)
};
companion object {
val reader = OrdinalReader(values())
}
override val bits = OrdinalBits(ordinal)
abstract fun wrap(stream: OutputStream): OutputStream
abstract fun wrap(stream: InputStream): InputStream
}
@VisibleForTesting
internal val encodingNotPermittedFormat = "Encoding not permitted: %s"

View File

@ -18,13 +18,18 @@ import java.util.concurrent.ExecutionException
val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled" val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled"
data class SerializationContextImpl(override val preferredSerializationVersion: SerializationMagic, internal object NullEncodingWhitelist : EncodingWhitelist {
override val deserializationClassLoader: ClassLoader, override fun acceptEncoding(encoding: SerializationEncoding) = false
override val whitelist: ClassWhitelist, }
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase) : SerializationContext {
data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic,
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase,
override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : SerializationContext {
private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build() private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()
/** /**
@ -70,6 +75,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
} }
override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic) override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic)
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
} }
open class SerializationFactoryImpl : SerializationFactory() { open class SerializationFactoryImpl : SerializationFactory() {

View File

@ -27,22 +27,26 @@ val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(kryoMagic,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.RPCServer) SerializationContext.UseCase.RPCServer,
null)
val KRYO_STORAGE_CONTEXT = SerializationContextImpl(kryoMagic, val KRYO_STORAGE_CONTEXT = SerializationContextImpl(kryoMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
AllButBlacklisted, AllButBlacklisted,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Storage) SerializationContext.UseCase.Storage,
null)
val AMQP_STORAGE_CONTEXT = SerializationContextImpl(amqpMagic, val AMQP_STORAGE_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
AllButBlacklisted, AllButBlacklisted,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Storage) SerializationContext.UseCase.Storage,
null)
val AMQP_RPC_SERVER_CONTEXT = SerializationContextImpl(amqpMagic, val AMQP_RPC_SERVER_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.RPCServer) SerializationContext.UseCase.RPCServer,
null)

View File

@ -20,18 +20,19 @@ val KRYO_P2P_CONTEXT = SerializationContextImpl(kryoMagic,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.P2P) SerializationContext.UseCase.P2P,
null)
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(kryoMagic, val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(kryoMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
QuasarWhitelist, QuasarWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Checkpoint) SerializationContext.UseCase.Checkpoint,
null)
val AMQP_P2P_CONTEXT = SerializationContextImpl(amqpMagic, val AMQP_P2P_CONTEXT = SerializationContextImpl(amqpMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.P2P) SerializationContext.UseCase.P2P,
null)

View File

@ -0,0 +1,31 @@
package net.corda.nodeapi.internal.serialization.amqp
import com.esotericsoftware.kryo.io.ByteBufferInputStream
import net.corda.nodeapi.internal.serialization.kryo.ByteBufferOutputStream
import net.corda.nodeapi.internal.serialization.kryo.serializeOutputStreamPool
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
fun InputStream.asByteBuffer(): ByteBuffer {
return if (this is ByteBufferInputStream) {
byteBuffer // BBIS has no other state, so this is perfectly safe.
} else {
ByteBuffer.wrap(serializeOutputStreamPool.run {
copyTo(it)
it.toByteArray()
})
}
}
fun <T> OutputStream.alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T {
return if (this is ByteBufferOutputStream) {
alsoAsByteBuffer(remaining, task)
} else {
serializeOutputStreamPool.run {
val result = it.alsoAsByteBuffer(remaining, task)
it.copyTo(this)
result
}
}
}

View File

@ -1,18 +1,27 @@
package net.corda.nodeapi.internal.serialization.amqp package net.corda.nodeapi.internal.serialization.amqp
import com.esotericsoftware.kryo.io.ByteBufferInputStream
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.getStackTraceAsString import net.corda.core.internal.getStackTraceAsString
import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding
import net.corda.nodeapi.internal.serialization.NullEncodingWhitelist
import net.corda.nodeapi.internal.serialization.SectionId
import net.corda.nodeapi.internal.serialization.encodingNotPermittedFormat
import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.amqp.DescribedType import org.apache.qpid.proton.amqp.DescribedType
import org.apache.qpid.proton.amqp.UnsignedByte import org.apache.qpid.proton.amqp.UnsignedByte
import org.apache.qpid.proton.amqp.UnsignedInteger import org.apache.qpid.proton.amqp.UnsignedInteger
import org.apache.qpid.proton.codec.Data import org.apache.qpid.proton.codec.Data
import java.io.InputStream
import java.io.NotSerializableException import java.io.NotSerializableException
import java.lang.reflect.ParameterizedType import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type import java.lang.reflect.Type
import java.lang.reflect.TypeVariable import java.lang.reflect.TypeVariable
import java.lang.reflect.WildcardType import java.lang.reflect.WildcardType
import java.nio.ByteBuffer
data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope) data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
@ -22,7 +31,8 @@ data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple * @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
* instances and threads. * instances and threads.
*/ */
class DeserializationInput(internal val serializerFactory: SerializerFactory) { class DeserializationInput @JvmOverloads constructor(private val serializerFactory: SerializerFactory,
private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) {
private val objectHistory: MutableList<Any> = mutableListOf() private val objectHistory: MutableList<Any> = mutableListOf()
internal companion object { internal companion object {
@ -47,6 +57,28 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
} }
return size + BYTES_NEEDED_TO_PEEK return size + BYTES_NEEDED_TO_PEEK
} }
@VisibleForTesting
@Throws(NotSerializableException::class)
internal fun <T> withDataBytes(byteSequence: ByteSequence, encodingWhitelist: EncodingWhitelist, task: (ByteBuffer) -> T): T {
// Check that the lead bytes match expected header
val amqpSequence = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.")
var stream: InputStream = ByteBufferInputStream(amqpSequence)
try {
while (true) {
when (SectionId.reader.readFrom(stream)) {
SectionId.ENCODING -> {
val encoding = CordaSerializationEncoding.reader.readFrom(stream)
encodingWhitelist.acceptEncoding(encoding) || throw NotSerializableException(encodingNotPermittedFormat.format(encoding))
stream = encoding.wrap(stream)
}
SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> return task(stream.asByteBuffer())
}
}
} finally {
stream.close()
}
}
} }
@Throws(NotSerializableException::class) @Throws(NotSerializableException::class)
@ -58,12 +90,12 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
@Throws(NotSerializableException::class) @Throws(NotSerializableException::class)
internal fun getEnvelope(byteSequence: ByteSequence): Envelope { internal fun getEnvelope(byteSequence: ByteSequence): Envelope {
// Check that the lead bytes match expected header return withDataBytes(byteSequence, encodingWhitelist) { dataBytes ->
val dataBytes = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.") val data = Data.Factory.create()
val data = Data.Factory.create() val expectedSize = dataBytes.remaining()
val expectedSize = dataBytes.remaining() if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data")
if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data") Envelope.get(data)
return Envelope.get(data) }
} }
@Throws(NotSerializableException::class) @Throws(NotSerializableException::class)

View File

@ -12,7 +12,7 @@ import net.corda.nodeapi.internal.serialization.carpenter.Field as CarpenterFiel
import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema
const val DESCRIPTOR_DOMAIN: String = "net.corda" const val DESCRIPTOR_DOMAIN: String = "net.corda"
val amqpMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(1, 0, 0)) val amqpMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(1, 0))
/** /**
* This and the classes below are OO representations of the AMQP XML schema described in the specification. Their * This and the classes below are OO representations of the AMQP XML schema described in the specification. Their

View File

@ -1,10 +1,14 @@
package net.corda.nodeapi.internal.serialization.amqp package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding
import net.corda.nodeapi.internal.serialization.SectionId
import net.corda.nodeapi.internal.serialization.kryo.byteArrayOutput
import org.apache.qpid.proton.codec.Data import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException import java.io.NotSerializableException
import java.io.OutputStream
import java.lang.reflect.Type import java.lang.reflect.Type
import java.nio.ByteBuffer
import java.util.* import java.util.*
import kotlin.collections.LinkedHashSet import kotlin.collections.LinkedHashSet
@ -19,8 +23,7 @@ data class BytesAndSchemas<T : Any>(
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple * @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
* instances and threads. * instances and threads.
*/ */
open class SerializationOutput(internal val serializerFactory: SerializerFactory) { open class SerializationOutput @JvmOverloads constructor(internal val serializerFactory: SerializerFactory, private val encoding: SerializationEncoding? = null) {
private val objectHistory: MutableMap<Any, Int> = IdentityHashMap() private val objectHistory: MutableMap<Any, Int> = IdentityHashMap()
private val serializerHistory: MutableSet<AMQPSerializer<*>> = LinkedHashSet() private val serializerHistory: MutableSet<AMQPSerializer<*>> = LinkedHashSet()
internal val schemaHistory: MutableSet<TypeNotation> = LinkedHashSet() internal val schemaHistory: MutableSet<TypeNotation> = LinkedHashSet()
@ -67,11 +70,21 @@ open class SerializationOutput(internal val serializerFactory: SerializerFactory
writeTransformSchema(TransformsSchema.build(schema, serializerFactory), this) writeTransformSchema(TransformsSchema.build(schema, serializerFactory), this)
} }
} }
val bytes = ByteArray(data.encodedSize().toInt() + 8) return SerializedBytes(byteArrayOutput {
val buf = ByteBuffer.wrap(bytes) var stream: OutputStream = it
amqpMagic.putTo(buf) try {
data.encode(buf) amqpMagic.writeTo(stream)
return SerializedBytes(bytes) if (encoding != null) {
SectionId.ENCODING.writeTo(stream)
(encoding as CordaSerializationEncoding).writeTo(stream)
stream = encoding.wrap(stream)
}
SectionId.DATA_AND_STOP.writeTo(stream)
stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode)
} finally {
stream.close()
}
})
} }
internal fun writeObject(obj: Any, data: Data) { internal fun writeObject(obj: Any, data: Data) {

View File

@ -16,10 +16,12 @@ import net.corda.core.utilities.ByteSequence
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.nodeapi.internal.serialization.CordaSerializationMagic import net.corda.nodeapi.internal.serialization.CordaSerializationMagic
import net.corda.nodeapi.internal.serialization.CordaClassResolver import net.corda.nodeapi.internal.serialization.CordaClassResolver
import net.corda.nodeapi.internal.serialization.SectionId
import net.corda.nodeapi.internal.serialization.SerializationScheme import net.corda.nodeapi.internal.serialization.SerializationScheme
import net.corda.nodeapi.internal.serialization.*
import java.security.PublicKey import java.security.PublicKey
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0, 1)) val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>() { private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>() {
override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) { override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) {
@ -87,11 +89,25 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
val dataBytes = kryoMagic.consume(byteSequence) ?: throw KryoException("Serialized bytes header does not match expected format.") val dataBytes = kryoMagic.consume(byteSequence) ?: throw KryoException("Serialized bytes header does not match expected format.")
return context.kryo { return context.kryo {
kryoInput(ByteBufferInputStream(dataBytes)) { kryoInput(ByteBufferInputStream(dataBytes)) {
if (context.objectReferencesEnabled) { val result: T
uncheckedCast(readClassAndObject(this)) loop@ while (true) {
} else { when (SectionId.reader.readFrom(this)) {
withoutReferences { uncheckedCast<Any?, T>(readClassAndObject(this)) } SectionId.ENCODING -> {
val encoding = CordaSerializationEncoding.reader.readFrom(this)
context.encodingWhitelist.acceptEncoding(encoding) || throw KryoException(encodingNotPermittedFormat.format(encoding))
substitute(encoding::wrap)
}
SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> {
result = if (context.objectReferencesEnabled) {
uncheckedCast(readClassAndObject(this))
} else {
withoutReferences { uncheckedCast<Any?, T>(readClassAndObject(this)) }
}
break@loop
}
}
} }
result
} }
} }
} }
@ -100,6 +116,12 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
return context.kryo { return context.kryo {
SerializedBytes(kryoOutput { SerializedBytes(kryoOutput {
kryoMagic.writeTo(this) kryoMagic.writeTo(this)
context.encoding?.let { encoding ->
SectionId.ENCODING.writeTo(this)
(encoding as CordaSerializationEncoding).writeTo(this)
substitute(encoding::wrap)
}
SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case.
if (context.objectReferencesEnabled) { if (context.objectReferencesEnabled) {
writeClassAndObject(this, obj) writeClassAndObject(this, obj)
} else { } else {

View File

@ -4,13 +4,34 @@ import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import net.corda.core.internal.LazyPool import net.corda.core.internal.LazyPool
import java.io.* import java.io.*
import java.nio.ByteBuffer
class ByteBufferOutputStream(size: Int) : ByteArrayOutputStream(size) {
companion object {
private val ensureCapacity = ByteArrayOutputStream::class.java.getDeclaredMethod("ensureCapacity", Int::class.java).apply {
isAccessible = true
}
}
fun <T> alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T {
ensureCapacity.invoke(this, count + remaining)
val buffer = ByteBuffer.wrap(buf, count, remaining)
val result = task(buffer)
count = buffer.position()
return result
}
fun copyTo(stream: OutputStream) {
stream.write(buf, 0, count)
}
}
private val serializationBufferPool = LazyPool( private val serializationBufferPool = LazyPool(
newInstance = { ByteArray(64 * 1024) }) newInstance = { ByteArray(64 * 1024) })
private val serializeOutputStreamPool = LazyPool( internal val serializeOutputStreamPool = LazyPool(
clear = ByteArrayOutputStream::reset, clear = ByteBufferOutputStream::reset,
shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large
newInstance = { ByteArrayOutputStream(64 * 1024) }) newInstance = { ByteBufferOutputStream(64 * 1024) })
internal fun <T> kryoInput(underlying: InputStream, task: Input.() -> T): T { internal fun <T> kryoInput(underlying: InputStream, task: Input.() -> T): T {
return serializationBufferPool.run { return serializationBufferPool.run {
@ -22,13 +43,19 @@ internal fun <T> kryoInput(underlying: InputStream, task: Input.() -> T): T {
} }
internal fun <T> kryoOutput(task: Output.() -> T): ByteArray { internal fun <T> kryoOutput(task: Output.() -> T): ByteArray {
return serializeOutputStreamPool.run { underlying -> return byteArrayOutput { underlying ->
serializationBufferPool.run { serializationBufferPool.run {
Output(it).use { output -> Output(it).use { output ->
output.outputStream = underlying output.outputStream = underlying
output.task() output.task()
} }
} }
}
}
internal fun <T> byteArrayOutput(task: (ByteBufferOutputStream) -> T): ByteArray {
return serializeOutputStreamPool.run { underlying ->
task(underlying)
underlying.toByteArray() // Must happen after close, to allow ZIP footer to be written for example. underlying.toByteArray() // Must happen after close, to allow ZIP footer to be written for example.
} }
} }

View File

@ -33,7 +33,7 @@ public final class ForbiddenLambdaSerializationTests {
EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint)); EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint));
contexts.forEach(ctx -> { contexts.forEach(ctx -> {
SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx); SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx, null);
String value = "Hey"; String value = "Hey";
Callable<String> target = (Callable<String> & Serializable) () -> value; Callable<String> target = (Callable<String> & Serializable) () -> value;
@ -55,7 +55,7 @@ public final class ForbiddenLambdaSerializationTests {
EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint)); EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint));
contexts.forEach(ctx -> { contexts.forEach(ctx -> {
SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx); SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx, null);
String value = "Hey"; String value = "Hey";
Callable<String> target = () -> value; Callable<String> target = () -> value;

View File

@ -26,7 +26,7 @@ public final class LambdaCheckpointSerializationTest {
@Before @Before
public void setup() { public void setup() {
factory = testSerialization.getSerializationFactory(); factory = testSerialization.getSerializationFactory();
context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint); context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint, null);
} }
@Test @Test

View File

@ -319,7 +319,8 @@ class X509UtilitiesTest {
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.P2P) SerializationContext.UseCase.P2P,
null)
val expected = X509Utilities.createSelfSignedCACertificate(ALICE.name.x500Principal, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)) val expected = X509Utilities.createSelfSignedCACertificate(ALICE.name.x500Principal, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME))
val serialized = expected.serialize(factory, context).bytes val serialized = expected.serialize(factory, context).bytes
val actual = serialized.deserialize<X509Certificate>(factory, context) val actual = serialized.deserialize<X509Certificate>(factory, context)
@ -334,7 +335,8 @@ class X509UtilitiesTest {
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.P2P) SerializationContext.UseCase.P2P,
null)
val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE_NAME.x500Principal, rootCAKey) val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE_NAME.x500Principal, rootCAKey)
val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB_NAME.x500Principal, BOB.publicKey) val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB_NAME.x500Principal, BOB.publicKey)

View File

@ -108,8 +108,8 @@ class CordaClassResolverTests {
val emptyMapClass = mapOf<Any, Any>().javaClass val emptyMapClass = mapOf<Any, Any>().javaClass
} }
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P) private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null)
private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P) private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null)
@Test @Test
fun `Annotation on enum works for specialised entries`() { fun `Annotation on enum works for specialised entries`() {
CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java) CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java)

View File

@ -1,10 +1,13 @@
package net.corda.nodeapi.internal.serialization package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.google.common.primitives.Ints import com.google.common.primitives.Ints
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.PrivacySalt import net.corda.core.contracts.PrivacySalt
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow import net.corda.core.internal.FetchDataFlow
@ -16,24 +19,29 @@ import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.nodeapi.internal.serialization.kryo.kryoMagic import net.corda.nodeapi.internal.serialization.kryo.kryoMagic
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import org.assertj.core.api.Assertions.assertThat import net.corda.testing.internal.rigorousMock
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.*
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.InputStream import java.io.InputStream
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import kotlin.test.assertEquals import kotlin.test.*
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
class KryoTests { @RunWith(Parameterized::class)
class KryoTests(private val compression: CordaSerializationEncoding?) {
companion object { companion object {
private val ALICE_PUBKEY = TestIdentity(ALICE_NAME, 70).publicKey private val ALICE_PUBKEY = TestIdentity(ALICE_NAME, 70).publicKey
@Parameters(name = "{0}")
@JvmStatic
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
} }
private lateinit var factory: SerializationFactory private lateinit var factory: SerializationFactory
@ -47,7 +55,11 @@ class KryoTests {
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Storage) SerializationContext.UseCase.Storage,
compression,
rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
})
} }
@Test @Test
@ -259,7 +271,8 @@ class KryoTests {
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.P2P) SerializationContext.UseCase.P2P,
null)
pt.serialize(factory, context) pt.serialize(factory, context)
} }
@ -300,4 +313,24 @@ class KryoTests {
val exception2 = exception.serialize(factory, context).deserialize(factory, context) val exception2 = exception.serialize(factory, context).deserialize(factory, context)
assertEquals(randomHash, exception2.requested) assertEquals(randomHash, exception2.requested)
} }
@Test
fun `compression has the desired effect`() {
compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.serialize(factory, context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.deserialize(factory, context))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".serialize(factory, context)
catchThrowable { compressed.deserialize(factory, context) }.run {
assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
}
} }

View File

@ -69,6 +69,7 @@ class ListsSerializationTest {
val serializedForm = emptyList<Int>().serialize() val serializedForm = emptyList<Int>().serialize()
val output = ByteArrayOutputStream().apply { val output = ByteArrayOutputStream().apply {
kryoMagic.writeTo(this) kryoMagic.writeTo(this)
SectionId.ALT_DATA_AND_STOP.writeTo(this)
write(DefaultClassResolver.NAME + 2) write(DefaultClassResolver.NAME + 2)
write(nameID) write(nameID)
write(javaEmptyListClass.name.toAscii()) write(javaEmptyListClass.name.toAscii())

View File

@ -79,6 +79,7 @@ class MapsSerializationTest {
val serializedForm = emptyMap<Int, Int>().serialize() val serializedForm = emptyMap<Int, Int>().serialize()
val output = ByteArrayOutputStream().apply { val output = ByteArrayOutputStream().apply {
kryoMagic.writeTo(this) kryoMagic.writeTo(this)
SectionId.ALT_DATA_AND_STOP.writeTo(this)
write(DefaultClassResolver.NAME + 2) write(DefaultClassResolver.NAME + 2)
write(nameID) write(nameID)
write(javaEmptyMapClass.name.toAscii()) write(javaEmptyMapClass.name.toAscii())

View File

@ -99,6 +99,7 @@ class SerializationTokenTest {
val stream = ByteArrayOutputStream() val stream = ByteArrayOutputStream()
Output(stream).use { Output(stream).use {
kryoMagic.writeTo(it) kryoMagic.writeTo(it)
SectionId.ALT_DATA_AND_STOP.writeTo(it)
kryo.writeClass(it, SingletonSerializeAsToken::class.java) kryo.writeClass(it, SingletonSerializeAsToken::class.java)
kryo.writeObject(it, emptyList<Any>()) kryo.writeObject(it, emptyList<Any>())
} }

View File

@ -56,6 +56,7 @@ class SetsSerializationTest {
val serializedForm = emptySet<Int>().serialize() val serializedForm = emptySet<Int>().serialize()
val output = ByteArrayOutputStream().apply { val output = ByteArrayOutputStream().apply {
kryoMagic.writeTo(this) kryoMagic.writeTo(this)
SectionId.ALT_DATA_AND_STOP.writeTo(this)
write(DefaultClassResolver.NAME + 2) write(DefaultClassResolver.NAME + 2)
write(nameID) write(nameID)
write(javaEmptySetClass.name.toAscii()) write(javaEmptySetClass.name.toAscii())

View File

@ -2,6 +2,8 @@
package net.corda.nodeapi.internal.serialization.amqp package net.corda.nodeapi.internal.serialization.amqp
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.RPCException import net.corda.client.rpc.RPCException
import net.corda.core.CordaRuntimeException import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.* import net.corda.core.contracts.*
@ -11,21 +13,16 @@ import net.corda.core.flows.FlowException
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.AbstractAttachment import net.corda.core.internal.AbstractAttachment
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.*
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializationFactory
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.serialization.AllWhitelist import net.corda.nodeapi.internal.serialization.*
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
import net.corda.nodeapi.internal.serialization.GeneratedAttachment
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive
import net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer
import net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.internal.rigorousMock
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.qpid.proton.amqp.* import org.apache.qpid.proton.amqp.*
import org.apache.qpid.proton.codec.DecoderImpl import org.apache.qpid.proton.codec.DecoderImpl
@ -35,22 +32,23 @@ import org.junit.Assert.*
import org.junit.Ignore import org.junit.Ignore
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.IOException import java.io.IOException
import java.io.NotSerializableException import java.io.NotSerializableException
import java.lang.reflect.Type
import java.math.BigDecimal import java.math.BigDecimal
import java.nio.ByteBuffer
import java.time.* import java.time.*
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.full.superclasses import kotlin.reflect.full.superclasses
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNotNull import kotlin.test.assertNotNull
import kotlin.test.assertTrue import kotlin.test.assertTrue
class SerializationOutputTests { @RunWith(Parameterized::class)
class SerializationOutputTests(private val compression: CordaSerializationEncoding?) {
private companion object { private companion object {
val BOB_IDENTITY = TestIdentity(BOB_NAME, 80).identity val BOB_IDENTITY = TestIdentity(BOB_NAME, 80).identity
val megaCorp = TestIdentity(CordaX500Name("MegaCorp", "London", "GB")) val megaCorp = TestIdentity(CordaX500Name("MegaCorp", "London", "GB"))
@ -59,6 +57,9 @@ class SerializationOutputTests {
val MEGA_CORP_PUBKEY get() = megaCorp.publicKey val MEGA_CORP_PUBKEY get() = megaCorp.publicKey
val MINI_CORP get() = miniCorp.party val MINI_CORP get() = miniCorp.party
val MINI_CORP_PUBKEY get() = miniCorp.publicKey val MINI_CORP_PUBKEY get() = miniCorp.publicKey
@Parameters(name = "{0}")
@JvmStatic
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
} }
@Rule @Rule
@ -173,16 +174,20 @@ class SerializationOutputTests {
} }
} }
private val encodingWhitelist = rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
}
private fun defaultFactory() = SerializerFactory(
AllWhitelist, ClassLoader.getSystemClassLoader(),
EvolutionSerializerGetterTesting())
private inline fun <reified T : Any> serdes(obj: T, private inline fun <reified T : Any> serdes(obj: T,
factory: SerializerFactory = SerializerFactory( factory: SerializerFactory = defaultFactory(),
AllWhitelist, ClassLoader.getSystemClassLoader(), freshDeserializationFactory: SerializerFactory = defaultFactory(),
EvolutionSerializerGetterTesting()),
freshDeserializationFactory: SerializerFactory = SerializerFactory(
AllWhitelist, ClassLoader.getSystemClassLoader(),
EvolutionSerializerGetterTesting()),
expectedEqual: Boolean = true, expectedEqual: Boolean = true,
expectDeserializedEqual: Boolean = true): T { expectDeserializedEqual: Boolean = true): T {
val ser = SerializationOutput(factory) val ser = SerializationOutput(factory, compression)
val bytes = ser.serialize(obj) val bytes = ser.serialize(obj)
val decoder = DecoderImpl().apply { val decoder = DecoderImpl().apply {
@ -198,18 +203,19 @@ class SerializationOutputTests {
this.register(TransformTypes.DESCRIPTOR, TransformTypes.Companion) this.register(TransformTypes.DESCRIPTOR, TransformTypes.Companion)
} }
EncoderImpl(decoder) EncoderImpl(decoder)
decoder.setByteBuffer(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8)) DeserializationInput.withDataBytes(bytes, encodingWhitelist) {
// Check that a vanilla AMQP decoder can deserialize without schema. decoder.setByteBuffer(it)
val result = decoder.readObject() as Envelope // Check that a vanilla AMQP decoder can deserialize without schema.
assertNotNull(result) val result = decoder.readObject() as Envelope
assertNotNull(result)
val des = DeserializationInput(freshDeserializationFactory) }
val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist)
val desObj = des.deserialize(bytes) val desObj = des.deserialize(bytes)
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual) assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
// Now repeat with a re-used factory // Now repeat with a re-used factory
val ser2 = SerializationOutput(factory) val ser2 = SerializationOutput(factory, compression)
val des2 = DeserializationInput(factory) val des2 = DeserializationInput(factory, encodingWhitelist)
val desObj2 = des2.deserialize(ser2.serialize(obj)) val desObj2 = des2.deserialize(ser2.serialize(obj))
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual) assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual) assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
@ -432,9 +438,9 @@ class SerializationOutputTests {
@Test @Test
fun `class constructor is invoked on deserialisation`() { fun `class constructor is invoked on deserialisation`() {
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())) compression == null || return // Manipulation of serialized bytes is invalid if they're compressed.
val des = DeserializationInput(ser.serializerFactory) val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression)
val des = DeserializationInput(ser.serializerFactory, encodingWhitelist)
val serialisedOne = ser.serialize(NonZeroByte(1)).bytes val serialisedOne = ser.serialize(NonZeroByte(1)).bytes
val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes
@ -1116,6 +1122,29 @@ class SerializationOutputTests {
val c = C(Amount<Currency>(100, BigDecimal("1.5"), Currency.getInstance("USD"))) val c = C(Amount<Currency>(100, BigDecimal("1.5"), Currency.getInstance("USD")))
// were the issue not fixed we'd blow up here // were the issue not fixed we'd blow up here
SerializationOutput(factory).serialize(c) SerializationOutput(factory, compression).serialize(c)
}
@Test
fun `compression has the desired effect`() {
compression ?: return
val factory = defaultFactory()
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = SerializationOutput(factory, compression).serialize(data)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
val factory = defaultFactory()
doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression)
val compressed = SerializationOutput(factory, compression).serialize("whatever")
val input = DeserializationInput(factory, encodingWhitelist)
catchThrowable { input.deserialize(compressed) }.run {
assertSame(NotSerializableException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
} }
} }

View File

@ -1,12 +1,16 @@
package net.corda.nodeapi.internal.serialization.kryo package net.corda.nodeapi.internal.serialization.kryo
import net.corda.core.internal.declaredField
import org.assertj.core.api.Assertions.catchThrowable
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Test import org.junit.Test
import java.io.* import java.io.*
import java.nio.BufferOverflowException
import java.util.* import java.util.*
import java.util.zip.DeflaterOutputStream import java.util.zip.DeflaterOutputStream
import java.util.zip.InflaterInputStream import java.util.zip.InflaterInputStream
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertSame
class KryoStreamsTest { class KryoStreamsTest {
class NegOutputStream(private val stream: OutputStream) : OutputStream() { class NegOutputStream(private val stream: OutputStream) : OutputStream() {
@ -57,4 +61,37 @@ class KryoStreamsTest {
assertEquals(-1, read()) assertEquals(-1, read())
} }
} }
@Test
fun `ByteBufferOutputStream works`() {
val stream = ByteBufferOutputStream(3)
stream.write("abc".toByteArray())
val getBuf = stream.declaredField<ByteArray>(ByteArrayOutputStream::class, "buf")::value
assertEquals(3, getBuf().size)
repeat(2) {
assertSame<Any>(BufferOverflowException::class.java, catchThrowable {
stream.alsoAsByteBuffer(9) {
it.put("0123456789".toByteArray())
}
}.javaClass)
assertEquals(3 + 9, getBuf().size)
}
// This time make too much space:
stream.alsoAsByteBuffer(11) {
it.put("0123456789".toByteArray())
}
stream.write("def".toByteArray())
assertArrayEquals("abc0123456789def".toByteArray(), stream.toByteArray())
}
@Test
fun `ByteBufferOutputStream discards data after final position`() {
val stream = ByteBufferOutputStream(0)
stream.alsoAsByteBuffer(10) {
it.put("0123456789".toByteArray())
it.position(5)
}
stream.write("def".toByteArray())
assertArrayEquals("01234def".toByteArray(), stream.toByteArray())
}
} }