mirror of
https://github.com/corda/corda.git
synced 2025-05-21 01:37:41 +00:00
CORDA-973 Compression support for serialization (#2473)
* Serialization magic is now 7 bytes * Introduce encoding property and whitelist
This commit is contained in:
parent
2af0feee04
commit
c8672d373f
@ -2817,7 +2817,7 @@ public final class net.corda.core.serialization.ObjectWithCompatibleContext exte
|
||||
public final class net.corda.core.serialization.SerializationAPIKt extends java.lang.Object
|
||||
@org.jetbrains.annotations.NotNull public static final net.corda.core.serialization.SerializedBytes serialize(Object, net.corda.core.serialization.SerializationFactory, net.corda.core.serialization.SerializationContext)
|
||||
##
|
||||
public interface net.corda.core.serialization.SerializationContext
|
||||
@net.corda.core.DoNotImplement public interface net.corda.core.serialization.SerializationContext
|
||||
@org.jetbrains.annotations.NotNull public abstract ClassLoader getDeserializationClassLoader()
|
||||
public abstract boolean getObjectReferencesEnabled()
|
||||
@org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.ByteSequence getPreferredSerializationVersion()
|
||||
|
@ -1,5 +1,6 @@
|
||||
package net.corda.core.serialization
|
||||
|
||||
import net.corda.core.DoNotImplement
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.sha256
|
||||
import net.corda.core.serialization.internal.effectiveSerializationEnv
|
||||
@ -99,14 +100,22 @@ abstract class SerializationFactory {
|
||||
}
|
||||
}
|
||||
typealias SerializationMagic = ByteSequence
|
||||
@DoNotImplement
|
||||
interface SerializationEncoding
|
||||
|
||||
/**
|
||||
* Parameters to serialization and deserialization.
|
||||
*/
|
||||
@DoNotImplement
|
||||
interface SerializationContext {
|
||||
/**
|
||||
* When serializing, use the format this header sequence represents.
|
||||
*/
|
||||
val preferredSerializationVersion: SerializationMagic
|
||||
/**
|
||||
* If non-null, apply this encoding (typically compression) when serializing.
|
||||
*/
|
||||
val encoding: SerializationEncoding?
|
||||
/**
|
||||
* The class loader to use for deserialization.
|
||||
*/
|
||||
@ -115,6 +124,10 @@ interface SerializationContext {
|
||||
* A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized.
|
||||
*/
|
||||
val whitelist: ClassWhitelist
|
||||
/**
|
||||
* A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing.
|
||||
*/
|
||||
val encodingWhitelist: EncodingWhitelist
|
||||
/**
|
||||
* A map of any addition properties specific to the particular use case.
|
||||
*/
|
||||
@ -161,6 +174,11 @@ interface SerializationContext {
|
||||
*/
|
||||
fun withPreferredSerializationVersion(magic: SerializationMagic): SerializationContext
|
||||
|
||||
/**
|
||||
* A shallow copy of this context but with the given (possibly null) encoding.
|
||||
*/
|
||||
fun withEncoding(encoding: SerializationEncoding?): SerializationContext
|
||||
|
||||
/**
|
||||
* The use case that we are serializing for, since it influences the implementations chosen.
|
||||
*/
|
||||
@ -232,3 +250,8 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {
|
||||
interface ClassWhitelist {
|
||||
fun hasListed(type: Class<*>): Boolean
|
||||
}
|
||||
|
||||
@DoNotImplement
|
||||
interface EncodingWhitelist {
|
||||
fun acceptEncoding(encoding: SerializationEncoding): Boolean
|
||||
}
|
||||
|
@ -34,6 +34,9 @@ dependencies {
|
||||
// For AMQP serialisation.
|
||||
compile "org.apache.qpid:proton-j:0.21.0"
|
||||
|
||||
// Pure-Java Snappy compression
|
||||
compile 'org.iq80.snappy:snappy:0.4'
|
||||
|
||||
// Unit testing helpers.
|
||||
testCompile "junit:junit:$junit_version"
|
||||
testCompile "org.assertj:assertj-core:$assertj_version"
|
||||
|
@ -18,10 +18,12 @@ val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(kryoMagic,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.RPCClient)
|
||||
SerializationContext.UseCase.RPCClient,
|
||||
null)
|
||||
val AMQP_RPC_CLIENT_CONTEXT = SerializationContextImpl(amqpMagic,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.RPCClient)
|
||||
SerializationContext.UseCase.RPCClient,
|
||||
null)
|
||||
|
@ -0,0 +1,31 @@
|
||||
package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import java.io.EOFException
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
class OrdinalBits(private val ordinal: Int) {
|
||||
interface OrdinalWriter {
|
||||
val bits: OrdinalBits
|
||||
val encodedSize get() = 1
|
||||
fun writeTo(stream: OutputStream) = stream.write(bits.ordinal)
|
||||
fun putTo(buffer: ByteBuffer) = buffer.put(bits.ordinal.toByte())!!
|
||||
}
|
||||
|
||||
init {
|
||||
require(ordinal >= 0) { "The ordinal must be non-negative." }
|
||||
require(ordinal < 128) { "Consider implementing a varint encoding." }
|
||||
}
|
||||
}
|
||||
|
||||
class OrdinalReader<out E : Any>(private val values: Array<E>) {
|
||||
private val enumName = values[0].javaClass.simpleName
|
||||
private val range = 0 until values.size
|
||||
fun readFrom(stream: InputStream): E {
|
||||
val ordinal = stream.read()
|
||||
if (ordinal == -1) throw EOFException("Expected a $enumName ordinal.")
|
||||
if (ordinal !in range) throw NoSuchElementException("No $enumName with ordinal: $ordinal")
|
||||
return values[ordinal]
|
||||
}
|
||||
}
|
@ -1,8 +1,17 @@
|
||||
package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import net.corda.core.internal.VisibleForTesting
|
||||
import net.corda.core.serialization.SerializationEncoding
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.nodeapi.internal.serialization.OrdinalBits.OrdinalWriter
|
||||
import org.iq80.snappy.SnappyFramedInputStream
|
||||
import org.iq80.snappy.SnappyFramedOutputStream
|
||||
import java.io.OutputStream
|
||||
import java.io.InputStream
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.zip.DeflaterOutputStream
|
||||
import java.util.zip.InflaterInputStream
|
||||
|
||||
class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) {
|
||||
private val bufferView = slice()
|
||||
@ -10,3 +19,40 @@ class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) {
|
||||
return if (data.slice(end = size) == bufferView) data.slice(size) else null
|
||||
}
|
||||
}
|
||||
|
||||
enum class SectionId : OrdinalWriter {
|
||||
/** Serialization data follows, and then discard the rest of the stream (if any) as legacy data may have trailing garbage. */
|
||||
DATA_AND_STOP,
|
||||
/** Identical behaviour to [DATA_AND_STOP], historically used for Kryo. Do not use in new code. */
|
||||
ALT_DATA_AND_STOP,
|
||||
/** The ordinal of a [CordaSerializationEncoding] follows, which should be used to decode the remainder of the stream. */
|
||||
ENCODING;
|
||||
|
||||
companion object {
|
||||
val reader = OrdinalReader(values())
|
||||
}
|
||||
|
||||
override val bits = OrdinalBits(ordinal)
|
||||
}
|
||||
|
||||
enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
|
||||
DEFLATE {
|
||||
override fun wrap(stream: OutputStream) = DeflaterOutputStream(stream)
|
||||
override fun wrap(stream: InputStream) = InflaterInputStream(stream)
|
||||
},
|
||||
SNAPPY {
|
||||
override fun wrap(stream: OutputStream) = SnappyFramedOutputStream(stream)
|
||||
override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false)
|
||||
};
|
||||
|
||||
companion object {
|
||||
val reader = OrdinalReader(values())
|
||||
}
|
||||
|
||||
override val bits = OrdinalBits(ordinal)
|
||||
abstract fun wrap(stream: OutputStream): OutputStream
|
||||
abstract fun wrap(stream: InputStream): InputStream
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
internal val encodingNotPermittedFormat = "Encoding not permitted: %s"
|
||||
|
@ -18,13 +18,18 @@ import java.util.concurrent.ExecutionException
|
||||
|
||||
val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled"
|
||||
|
||||
data class SerializationContextImpl(override val preferredSerializationVersion: SerializationMagic,
|
||||
override val deserializationClassLoader: ClassLoader,
|
||||
override val whitelist: ClassWhitelist,
|
||||
override val properties: Map<Any, Any>,
|
||||
override val objectReferencesEnabled: Boolean,
|
||||
override val useCase: SerializationContext.UseCase) : SerializationContext {
|
||||
internal object NullEncodingWhitelist : EncodingWhitelist {
|
||||
override fun acceptEncoding(encoding: SerializationEncoding) = false
|
||||
}
|
||||
|
||||
data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic,
|
||||
override val deserializationClassLoader: ClassLoader,
|
||||
override val whitelist: ClassWhitelist,
|
||||
override val properties: Map<Any, Any>,
|
||||
override val objectReferencesEnabled: Boolean,
|
||||
override val useCase: SerializationContext.UseCase,
|
||||
override val encoding: SerializationEncoding?,
|
||||
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : SerializationContext {
|
||||
private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()
|
||||
|
||||
/**
|
||||
@ -70,6 +75,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
|
||||
}
|
||||
|
||||
override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic)
|
||||
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
|
||||
}
|
||||
|
||||
open class SerializationFactoryImpl : SerializationFactory() {
|
||||
|
@ -27,22 +27,26 @@ val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(kryoMagic,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.RPCServer)
|
||||
SerializationContext.UseCase.RPCServer,
|
||||
null)
|
||||
val KRYO_STORAGE_CONTEXT = SerializationContextImpl(kryoMagic,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
AllButBlacklisted,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.Storage)
|
||||
SerializationContext.UseCase.Storage,
|
||||
null)
|
||||
val AMQP_STORAGE_CONTEXT = SerializationContextImpl(amqpMagic,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
AllButBlacklisted,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.Storage)
|
||||
SerializationContext.UseCase.Storage,
|
||||
null)
|
||||
val AMQP_RPC_SERVER_CONTEXT = SerializationContextImpl(amqpMagic,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.RPCServer)
|
||||
SerializationContext.UseCase.RPCServer,
|
||||
null)
|
||||
|
@ -20,18 +20,19 @@ val KRYO_P2P_CONTEXT = SerializationContextImpl(kryoMagic,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.P2P)
|
||||
SerializationContext.UseCase.P2P,
|
||||
null)
|
||||
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(kryoMagic,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
QuasarWhitelist,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.Checkpoint)
|
||||
SerializationContext.UseCase.Checkpoint,
|
||||
null)
|
||||
val AMQP_P2P_CONTEXT = SerializationContextImpl(amqpMagic,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.P2P)
|
||||
|
||||
|
||||
SerializationContext.UseCase.P2P,
|
||||
null)
|
||||
|
@ -0,0 +1,31 @@
|
||||
package net.corda.nodeapi.internal.serialization.amqp
|
||||
|
||||
import com.esotericsoftware.kryo.io.ByteBufferInputStream
|
||||
import net.corda.nodeapi.internal.serialization.kryo.ByteBufferOutputStream
|
||||
import net.corda.nodeapi.internal.serialization.kryo.serializeOutputStreamPool
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
fun InputStream.asByteBuffer(): ByteBuffer {
|
||||
return if (this is ByteBufferInputStream) {
|
||||
byteBuffer // BBIS has no other state, so this is perfectly safe.
|
||||
} else {
|
||||
ByteBuffer.wrap(serializeOutputStreamPool.run {
|
||||
copyTo(it)
|
||||
it.toByteArray()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> OutputStream.alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T {
|
||||
return if (this is ByteBufferOutputStream) {
|
||||
alsoAsByteBuffer(remaining, task)
|
||||
} else {
|
||||
serializeOutputStreamPool.run {
|
||||
val result = it.alsoAsByteBuffer(remaining, task)
|
||||
it.copyTo(this)
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
@ -1,18 +1,27 @@
|
||||
package net.corda.nodeapi.internal.serialization.amqp
|
||||
|
||||
import com.esotericsoftware.kryo.io.ByteBufferInputStream
|
||||
import net.corda.core.internal.VisibleForTesting
|
||||
import net.corda.core.internal.getStackTraceAsString
|
||||
import net.corda.core.serialization.EncodingWhitelist
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding
|
||||
import net.corda.nodeapi.internal.serialization.NullEncodingWhitelist
|
||||
import net.corda.nodeapi.internal.serialization.SectionId
|
||||
import net.corda.nodeapi.internal.serialization.encodingNotPermittedFormat
|
||||
import org.apache.qpid.proton.amqp.Binary
|
||||
import org.apache.qpid.proton.amqp.DescribedType
|
||||
import org.apache.qpid.proton.amqp.UnsignedByte
|
||||
import org.apache.qpid.proton.amqp.UnsignedInteger
|
||||
import org.apache.qpid.proton.codec.Data
|
||||
import java.io.InputStream
|
||||
import java.io.NotSerializableException
|
||||
import java.lang.reflect.ParameterizedType
|
||||
import java.lang.reflect.Type
|
||||
import java.lang.reflect.TypeVariable
|
||||
import java.lang.reflect.WildcardType
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
|
||||
|
||||
@ -22,7 +31,8 @@ data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
|
||||
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
|
||||
* instances and threads.
|
||||
*/
|
||||
class DeserializationInput(internal val serializerFactory: SerializerFactory) {
|
||||
class DeserializationInput @JvmOverloads constructor(private val serializerFactory: SerializerFactory,
|
||||
private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) {
|
||||
private val objectHistory: MutableList<Any> = mutableListOf()
|
||||
|
||||
internal companion object {
|
||||
@ -47,6 +57,28 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
|
||||
}
|
||||
return size + BYTES_NEEDED_TO_PEEK
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
@Throws(NotSerializableException::class)
|
||||
internal fun <T> withDataBytes(byteSequence: ByteSequence, encodingWhitelist: EncodingWhitelist, task: (ByteBuffer) -> T): T {
|
||||
// Check that the lead bytes match expected header
|
||||
val amqpSequence = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.")
|
||||
var stream: InputStream = ByteBufferInputStream(amqpSequence)
|
||||
try {
|
||||
while (true) {
|
||||
when (SectionId.reader.readFrom(stream)) {
|
||||
SectionId.ENCODING -> {
|
||||
val encoding = CordaSerializationEncoding.reader.readFrom(stream)
|
||||
encodingWhitelist.acceptEncoding(encoding) || throw NotSerializableException(encodingNotPermittedFormat.format(encoding))
|
||||
stream = encoding.wrap(stream)
|
||||
}
|
||||
SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> return task(stream.asByteBuffer())
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
stream.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(NotSerializableException::class)
|
||||
@ -58,12 +90,12 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
|
||||
|
||||
@Throws(NotSerializableException::class)
|
||||
internal fun getEnvelope(byteSequence: ByteSequence): Envelope {
|
||||
// Check that the lead bytes match expected header
|
||||
val dataBytes = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.")
|
||||
val data = Data.Factory.create()
|
||||
val expectedSize = dataBytes.remaining()
|
||||
if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data")
|
||||
return Envelope.get(data)
|
||||
return withDataBytes(byteSequence, encodingWhitelist) { dataBytes ->
|
||||
val data = Data.Factory.create()
|
||||
val expectedSize = dataBytes.remaining()
|
||||
if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data")
|
||||
Envelope.get(data)
|
||||
}
|
||||
}
|
||||
|
||||
@Throws(NotSerializableException::class)
|
||||
|
@ -12,7 +12,7 @@ import net.corda.nodeapi.internal.serialization.carpenter.Field as CarpenterFiel
|
||||
import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema
|
||||
|
||||
const val DESCRIPTOR_DOMAIN: String = "net.corda"
|
||||
val amqpMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(1, 0, 0))
|
||||
val amqpMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(1, 0))
|
||||
|
||||
/**
|
||||
* This and the classes below are OO representations of the AMQP XML schema described in the specification. Their
|
||||
|
@ -1,10 +1,14 @@
|
||||
package net.corda.nodeapi.internal.serialization.amqp
|
||||
|
||||
import net.corda.core.serialization.SerializationEncoding
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding
|
||||
import net.corda.nodeapi.internal.serialization.SectionId
|
||||
import net.corda.nodeapi.internal.serialization.kryo.byteArrayOutput
|
||||
import org.apache.qpid.proton.codec.Data
|
||||
import java.io.NotSerializableException
|
||||
import java.io.OutputStream
|
||||
import java.lang.reflect.Type
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.*
|
||||
import kotlin.collections.LinkedHashSet
|
||||
|
||||
@ -19,8 +23,7 @@ data class BytesAndSchemas<T : Any>(
|
||||
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
|
||||
* instances and threads.
|
||||
*/
|
||||
open class SerializationOutput(internal val serializerFactory: SerializerFactory) {
|
||||
|
||||
open class SerializationOutput @JvmOverloads constructor(internal val serializerFactory: SerializerFactory, private val encoding: SerializationEncoding? = null) {
|
||||
private val objectHistory: MutableMap<Any, Int> = IdentityHashMap()
|
||||
private val serializerHistory: MutableSet<AMQPSerializer<*>> = LinkedHashSet()
|
||||
internal val schemaHistory: MutableSet<TypeNotation> = LinkedHashSet()
|
||||
@ -67,11 +70,21 @@ open class SerializationOutput(internal val serializerFactory: SerializerFactory
|
||||
writeTransformSchema(TransformsSchema.build(schema, serializerFactory), this)
|
||||
}
|
||||
}
|
||||
val bytes = ByteArray(data.encodedSize().toInt() + 8)
|
||||
val buf = ByteBuffer.wrap(bytes)
|
||||
amqpMagic.putTo(buf)
|
||||
data.encode(buf)
|
||||
return SerializedBytes(bytes)
|
||||
return SerializedBytes(byteArrayOutput {
|
||||
var stream: OutputStream = it
|
||||
try {
|
||||
amqpMagic.writeTo(stream)
|
||||
if (encoding != null) {
|
||||
SectionId.ENCODING.writeTo(stream)
|
||||
(encoding as CordaSerializationEncoding).writeTo(stream)
|
||||
stream = encoding.wrap(stream)
|
||||
}
|
||||
SectionId.DATA_AND_STOP.writeTo(stream)
|
||||
stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode)
|
||||
} finally {
|
||||
stream.close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
internal fun writeObject(obj: Any, data: Data) {
|
||||
|
@ -16,10 +16,12 @@ import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.nodeapi.internal.serialization.CordaSerializationMagic
|
||||
import net.corda.nodeapi.internal.serialization.CordaClassResolver
|
||||
import net.corda.nodeapi.internal.serialization.SectionId
|
||||
import net.corda.nodeapi.internal.serialization.SerializationScheme
|
||||
import net.corda.nodeapi.internal.serialization.*
|
||||
import java.security.PublicKey
|
||||
|
||||
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0, 1))
|
||||
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
|
||||
|
||||
private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>() {
|
||||
override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) {
|
||||
@ -87,11 +89,25 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
|
||||
val dataBytes = kryoMagic.consume(byteSequence) ?: throw KryoException("Serialized bytes header does not match expected format.")
|
||||
return context.kryo {
|
||||
kryoInput(ByteBufferInputStream(dataBytes)) {
|
||||
if (context.objectReferencesEnabled) {
|
||||
uncheckedCast(readClassAndObject(this))
|
||||
} else {
|
||||
withoutReferences { uncheckedCast<Any?, T>(readClassAndObject(this)) }
|
||||
val result: T
|
||||
loop@ while (true) {
|
||||
when (SectionId.reader.readFrom(this)) {
|
||||
SectionId.ENCODING -> {
|
||||
val encoding = CordaSerializationEncoding.reader.readFrom(this)
|
||||
context.encodingWhitelist.acceptEncoding(encoding) || throw KryoException(encodingNotPermittedFormat.format(encoding))
|
||||
substitute(encoding::wrap)
|
||||
}
|
||||
SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> {
|
||||
result = if (context.objectReferencesEnabled) {
|
||||
uncheckedCast(readClassAndObject(this))
|
||||
} else {
|
||||
withoutReferences { uncheckedCast<Any?, T>(readClassAndObject(this)) }
|
||||
}
|
||||
break@loop
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -100,6 +116,12 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
|
||||
return context.kryo {
|
||||
SerializedBytes(kryoOutput {
|
||||
kryoMagic.writeTo(this)
|
||||
context.encoding?.let { encoding ->
|
||||
SectionId.ENCODING.writeTo(this)
|
||||
(encoding as CordaSerializationEncoding).writeTo(this)
|
||||
substitute(encoding::wrap)
|
||||
}
|
||||
SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case.
|
||||
if (context.objectReferencesEnabled) {
|
||||
writeClassAndObject(this, obj)
|
||||
} else {
|
||||
|
@ -4,13 +4,34 @@ import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import net.corda.core.internal.LazyPool
|
||||
import java.io.*
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
class ByteBufferOutputStream(size: Int) : ByteArrayOutputStream(size) {
|
||||
companion object {
|
||||
private val ensureCapacity = ByteArrayOutputStream::class.java.getDeclaredMethod("ensureCapacity", Int::class.java).apply {
|
||||
isAccessible = true
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T {
|
||||
ensureCapacity.invoke(this, count + remaining)
|
||||
val buffer = ByteBuffer.wrap(buf, count, remaining)
|
||||
val result = task(buffer)
|
||||
count = buffer.position()
|
||||
return result
|
||||
}
|
||||
|
||||
fun copyTo(stream: OutputStream) {
|
||||
stream.write(buf, 0, count)
|
||||
}
|
||||
}
|
||||
|
||||
private val serializationBufferPool = LazyPool(
|
||||
newInstance = { ByteArray(64 * 1024) })
|
||||
private val serializeOutputStreamPool = LazyPool(
|
||||
clear = ByteArrayOutputStream::reset,
|
||||
internal val serializeOutputStreamPool = LazyPool(
|
||||
clear = ByteBufferOutputStream::reset,
|
||||
shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large
|
||||
newInstance = { ByteArrayOutputStream(64 * 1024) })
|
||||
newInstance = { ByteBufferOutputStream(64 * 1024) })
|
||||
|
||||
internal fun <T> kryoInput(underlying: InputStream, task: Input.() -> T): T {
|
||||
return serializationBufferPool.run {
|
||||
@ -22,13 +43,19 @@ internal fun <T> kryoInput(underlying: InputStream, task: Input.() -> T): T {
|
||||
}
|
||||
|
||||
internal fun <T> kryoOutput(task: Output.() -> T): ByteArray {
|
||||
return serializeOutputStreamPool.run { underlying ->
|
||||
return byteArrayOutput { underlying ->
|
||||
serializationBufferPool.run {
|
||||
Output(it).use { output ->
|
||||
output.outputStream = underlying
|
||||
output.task()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun <T> byteArrayOutput(task: (ByteBufferOutputStream) -> T): ByteArray {
|
||||
return serializeOutputStreamPool.run { underlying ->
|
||||
task(underlying)
|
||||
underlying.toByteArray() // Must happen after close, to allow ZIP footer to be written for example.
|
||||
}
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ public final class ForbiddenLambdaSerializationTests {
|
||||
EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint));
|
||||
|
||||
contexts.forEach(ctx -> {
|
||||
SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx);
|
||||
SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx, null);
|
||||
String value = "Hey";
|
||||
Callable<String> target = (Callable<String> & Serializable) () -> value;
|
||||
|
||||
@ -55,7 +55,7 @@ public final class ForbiddenLambdaSerializationTests {
|
||||
EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint));
|
||||
|
||||
contexts.forEach(ctx -> {
|
||||
SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx);
|
||||
SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx, null);
|
||||
String value = "Hey";
|
||||
Callable<String> target = () -> value;
|
||||
|
||||
|
@ -26,7 +26,7 @@ public final class LambdaCheckpointSerializationTest {
|
||||
@Before
|
||||
public void setup() {
|
||||
factory = testSerialization.getSerializationFactory();
|
||||
context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint);
|
||||
context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint, null);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
@ -319,7 +319,8 @@ class X509UtilitiesTest {
|
||||
AllWhitelist,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.P2P)
|
||||
SerializationContext.UseCase.P2P,
|
||||
null)
|
||||
val expected = X509Utilities.createSelfSignedCACertificate(ALICE.name.x500Principal, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME))
|
||||
val serialized = expected.serialize(factory, context).bytes
|
||||
val actual = serialized.deserialize<X509Certificate>(factory, context)
|
||||
@ -334,7 +335,8 @@ class X509UtilitiesTest {
|
||||
AllWhitelist,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.P2P)
|
||||
SerializationContext.UseCase.P2P,
|
||||
null)
|
||||
val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
|
||||
val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE_NAME.x500Principal, rootCAKey)
|
||||
val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB_NAME.x500Principal, BOB.publicKey)
|
||||
|
@ -108,8 +108,8 @@ class CordaClassResolverTests {
|
||||
val emptyMapClass = mapOf<Any, Any>().javaClass
|
||||
}
|
||||
|
||||
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P)
|
||||
private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P)
|
||||
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null)
|
||||
private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null)
|
||||
@Test
|
||||
fun `Annotation on enum works for specialised entries`() {
|
||||
CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java)
|
||||
|
@ -1,10 +1,13 @@
|
||||
package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.KryoException
|
||||
import com.esotericsoftware.kryo.KryoSerializable
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import com.google.common.primitives.Ints
|
||||
import com.nhaarman.mockito_kotlin.doReturn
|
||||
import com.nhaarman.mockito_kotlin.whenever
|
||||
import net.corda.core.contracts.PrivacySalt
|
||||
import net.corda.core.crypto.*
|
||||
import net.corda.core.internal.FetchDataFlow
|
||||
@ -16,24 +19,29 @@ import net.corda.node.services.persistence.NodeAttachmentService
|
||||
import net.corda.nodeapi.internal.serialization.kryo.kryoMagic
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.TestIdentity
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import org.assertj.core.api.Assertions.*
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import org.junit.runners.Parameterized.Parameters
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.InputStream
|
||||
import java.time.Instant
|
||||
import java.util.*
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFalse
|
||||
import kotlin.test.assertNotNull
|
||||
import kotlin.test.assertTrue
|
||||
import kotlin.test.*
|
||||
|
||||
class KryoTests {
|
||||
@RunWith(Parameterized::class)
|
||||
class KryoTests(private val compression: CordaSerializationEncoding?) {
|
||||
companion object {
|
||||
private val ALICE_PUBKEY = TestIdentity(ALICE_NAME, 70).publicKey
|
||||
@Parameters(name = "{0}")
|
||||
@JvmStatic
|
||||
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
|
||||
}
|
||||
|
||||
private lateinit var factory: SerializationFactory
|
||||
@ -47,7 +55,11 @@ class KryoTests {
|
||||
AllWhitelist,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.Storage)
|
||||
SerializationContext.UseCase.Storage,
|
||||
compression,
|
||||
rigorousMock<EncodingWhitelist>().also {
|
||||
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
|
||||
})
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -259,7 +271,8 @@ class KryoTests {
|
||||
AllWhitelist,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.P2P)
|
||||
SerializationContext.UseCase.P2P,
|
||||
null)
|
||||
pt.serialize(factory, context)
|
||||
}
|
||||
|
||||
@ -300,4 +313,24 @@ class KryoTests {
|
||||
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
|
||||
assertEquals(randomHash, exception2.requested)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compression has the desired effect`() {
|
||||
compression ?: return
|
||||
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
|
||||
val compressed = data.serialize(factory, context)
|
||||
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
|
||||
assertArrayEquals(data, compressed.deserialize(factory, context))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `a particular encoding can be banned for deserialization`() {
|
||||
compression ?: return
|
||||
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
|
||||
val compressed = "whatever".serialize(factory, context)
|
||||
catchThrowable { compressed.deserialize(factory, context) }.run {
|
||||
assertSame<Any>(KryoException::class.java, javaClass)
|
||||
assertEquals(encodingNotPermittedFormat.format(compression), message)
|
||||
}
|
||||
}
|
||||
}
|
@ -69,6 +69,7 @@ class ListsSerializationTest {
|
||||
val serializedForm = emptyList<Int>().serialize()
|
||||
val output = ByteArrayOutputStream().apply {
|
||||
kryoMagic.writeTo(this)
|
||||
SectionId.ALT_DATA_AND_STOP.writeTo(this)
|
||||
write(DefaultClassResolver.NAME + 2)
|
||||
write(nameID)
|
||||
write(javaEmptyListClass.name.toAscii())
|
||||
|
@ -79,6 +79,7 @@ class MapsSerializationTest {
|
||||
val serializedForm = emptyMap<Int, Int>().serialize()
|
||||
val output = ByteArrayOutputStream().apply {
|
||||
kryoMagic.writeTo(this)
|
||||
SectionId.ALT_DATA_AND_STOP.writeTo(this)
|
||||
write(DefaultClassResolver.NAME + 2)
|
||||
write(nameID)
|
||||
write(javaEmptyMapClass.name.toAscii())
|
||||
|
@ -99,6 +99,7 @@ class SerializationTokenTest {
|
||||
val stream = ByteArrayOutputStream()
|
||||
Output(stream).use {
|
||||
kryoMagic.writeTo(it)
|
||||
SectionId.ALT_DATA_AND_STOP.writeTo(it)
|
||||
kryo.writeClass(it, SingletonSerializeAsToken::class.java)
|
||||
kryo.writeObject(it, emptyList<Any>())
|
||||
}
|
||||
|
@ -56,6 +56,7 @@ class SetsSerializationTest {
|
||||
val serializedForm = emptySet<Int>().serialize()
|
||||
val output = ByteArrayOutputStream().apply {
|
||||
kryoMagic.writeTo(this)
|
||||
SectionId.ALT_DATA_AND_STOP.writeTo(this)
|
||||
write(DefaultClassResolver.NAME + 2)
|
||||
write(nameID)
|
||||
write(javaEmptySetClass.name.toAscii())
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
package net.corda.nodeapi.internal.serialization.amqp
|
||||
|
||||
import com.nhaarman.mockito_kotlin.doReturn
|
||||
import com.nhaarman.mockito_kotlin.whenever
|
||||
import net.corda.client.rpc.RPCException
|
||||
import net.corda.core.CordaRuntimeException
|
||||
import net.corda.core.contracts.*
|
||||
@ -11,21 +13,16 @@ import net.corda.core.flows.FlowException
|
||||
import net.corda.core.identity.AbstractParty
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.internal.AbstractAttachment
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.MissingAttachmentsException
|
||||
import net.corda.core.serialization.SerializationFactory
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.core.transactions.LedgerTransaction
|
||||
import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.nodeapi.internal.serialization.AllWhitelist
|
||||
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
|
||||
import net.corda.nodeapi.internal.serialization.GeneratedAttachment
|
||||
import net.corda.nodeapi.internal.serialization.*
|
||||
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive
|
||||
import net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer
|
||||
import net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer
|
||||
import net.corda.testing.contracts.DummyContract
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.core.TestIdentity
|
||||
import net.corda.testing.internal.rigorousMock
|
||||
import org.apache.activemq.artemis.api.core.SimpleString
|
||||
import org.apache.qpid.proton.amqp.*
|
||||
import org.apache.qpid.proton.codec.DecoderImpl
|
||||
@ -35,22 +32,23 @@ import org.junit.Assert.*
|
||||
import org.junit.Ignore
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
import org.junit.runners.Parameterized
|
||||
import org.junit.runners.Parameterized.Parameters
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.io.IOException
|
||||
import java.io.NotSerializableException
|
||||
import java.lang.reflect.Type
|
||||
import java.math.BigDecimal
|
||||
import java.nio.ByteBuffer
|
||||
import java.time.*
|
||||
import java.time.temporal.ChronoUnit
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import kotlin.reflect.full.superclasses
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertNotNull
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class SerializationOutputTests {
|
||||
@RunWith(Parameterized::class)
|
||||
class SerializationOutputTests(private val compression: CordaSerializationEncoding?) {
|
||||
private companion object {
|
||||
val BOB_IDENTITY = TestIdentity(BOB_NAME, 80).identity
|
||||
val megaCorp = TestIdentity(CordaX500Name("MegaCorp", "London", "GB"))
|
||||
@ -59,6 +57,9 @@ class SerializationOutputTests {
|
||||
val MEGA_CORP_PUBKEY get() = megaCorp.publicKey
|
||||
val MINI_CORP get() = miniCorp.party
|
||||
val MINI_CORP_PUBKEY get() = miniCorp.publicKey
|
||||
@Parameters(name = "{0}")
|
||||
@JvmStatic
|
||||
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
|
||||
}
|
||||
|
||||
@Rule
|
||||
@ -173,16 +174,20 @@ class SerializationOutputTests {
|
||||
}
|
||||
}
|
||||
|
||||
private val encodingWhitelist = rigorousMock<EncodingWhitelist>().also {
|
||||
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
|
||||
}
|
||||
|
||||
private fun defaultFactory() = SerializerFactory(
|
||||
AllWhitelist, ClassLoader.getSystemClassLoader(),
|
||||
EvolutionSerializerGetterTesting())
|
||||
|
||||
private inline fun <reified T : Any> serdes(obj: T,
|
||||
factory: SerializerFactory = SerializerFactory(
|
||||
AllWhitelist, ClassLoader.getSystemClassLoader(),
|
||||
EvolutionSerializerGetterTesting()),
|
||||
freshDeserializationFactory: SerializerFactory = SerializerFactory(
|
||||
AllWhitelist, ClassLoader.getSystemClassLoader(),
|
||||
EvolutionSerializerGetterTesting()),
|
||||
factory: SerializerFactory = defaultFactory(),
|
||||
freshDeserializationFactory: SerializerFactory = defaultFactory(),
|
||||
expectedEqual: Boolean = true,
|
||||
expectDeserializedEqual: Boolean = true): T {
|
||||
val ser = SerializationOutput(factory)
|
||||
val ser = SerializationOutput(factory, compression)
|
||||
val bytes = ser.serialize(obj)
|
||||
|
||||
val decoder = DecoderImpl().apply {
|
||||
@ -198,18 +203,19 @@ class SerializationOutputTests {
|
||||
this.register(TransformTypes.DESCRIPTOR, TransformTypes.Companion)
|
||||
}
|
||||
EncoderImpl(decoder)
|
||||
decoder.setByteBuffer(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8))
|
||||
// Check that a vanilla AMQP decoder can deserialize without schema.
|
||||
val result = decoder.readObject() as Envelope
|
||||
assertNotNull(result)
|
||||
|
||||
val des = DeserializationInput(freshDeserializationFactory)
|
||||
DeserializationInput.withDataBytes(bytes, encodingWhitelist) {
|
||||
decoder.setByteBuffer(it)
|
||||
// Check that a vanilla AMQP decoder can deserialize without schema.
|
||||
val result = decoder.readObject() as Envelope
|
||||
assertNotNull(result)
|
||||
}
|
||||
val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist)
|
||||
val desObj = des.deserialize(bytes)
|
||||
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
|
||||
|
||||
// Now repeat with a re-used factory
|
||||
val ser2 = SerializationOutput(factory)
|
||||
val des2 = DeserializationInput(factory)
|
||||
val ser2 = SerializationOutput(factory, compression)
|
||||
val des2 = DeserializationInput(factory, encodingWhitelist)
|
||||
val desObj2 = des2.deserialize(ser2.serialize(obj))
|
||||
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
|
||||
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
|
||||
@ -432,9 +438,9 @@ class SerializationOutputTests {
|
||||
|
||||
@Test
|
||||
fun `class constructor is invoked on deserialisation`() {
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
|
||||
val des = DeserializationInput(ser.serializerFactory)
|
||||
|
||||
compression == null || return // Manipulation of serialized bytes is invalid if they're compressed.
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression)
|
||||
val des = DeserializationInput(ser.serializerFactory, encodingWhitelist)
|
||||
val serialisedOne = ser.serialize(NonZeroByte(1)).bytes
|
||||
val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes
|
||||
|
||||
@ -1116,6 +1122,29 @@ class SerializationOutputTests {
|
||||
val c = C(Amount<Currency>(100, BigDecimal("1.5"), Currency.getInstance("USD")))
|
||||
|
||||
// were the issue not fixed we'd blow up here
|
||||
SerializationOutput(factory).serialize(c)
|
||||
SerializationOutput(factory, compression).serialize(c)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compression has the desired effect`() {
|
||||
compression ?: return
|
||||
val factory = defaultFactory()
|
||||
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
|
||||
val compressed = SerializationOutput(factory, compression).serialize(data)
|
||||
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
|
||||
assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `a particular encoding can be banned for deserialization`() {
|
||||
compression ?: return
|
||||
val factory = defaultFactory()
|
||||
doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression)
|
||||
val compressed = SerializationOutput(factory, compression).serialize("whatever")
|
||||
val input = DeserializationInput(factory, encodingWhitelist)
|
||||
catchThrowable { input.deserialize(compressed) }.run {
|
||||
assertSame(NotSerializableException::class.java, javaClass)
|
||||
assertEquals(encodingNotPermittedFormat.format(compression), message)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,12 +1,16 @@
|
||||
package net.corda.nodeapi.internal.serialization.kryo
|
||||
|
||||
import net.corda.core.internal.declaredField
|
||||
import org.assertj.core.api.Assertions.catchThrowable
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Test
|
||||
import java.io.*
|
||||
import java.nio.BufferOverflowException
|
||||
import java.util.*
|
||||
import java.util.zip.DeflaterOutputStream
|
||||
import java.util.zip.InflaterInputStream
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertSame
|
||||
|
||||
class KryoStreamsTest {
|
||||
class NegOutputStream(private val stream: OutputStream) : OutputStream() {
|
||||
@ -57,4 +61,37 @@ class KryoStreamsTest {
|
||||
assertEquals(-1, read())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `ByteBufferOutputStream works`() {
|
||||
val stream = ByteBufferOutputStream(3)
|
||||
stream.write("abc".toByteArray())
|
||||
val getBuf = stream.declaredField<ByteArray>(ByteArrayOutputStream::class, "buf")::value
|
||||
assertEquals(3, getBuf().size)
|
||||
repeat(2) {
|
||||
assertSame<Any>(BufferOverflowException::class.java, catchThrowable {
|
||||
stream.alsoAsByteBuffer(9) {
|
||||
it.put("0123456789".toByteArray())
|
||||
}
|
||||
}.javaClass)
|
||||
assertEquals(3 + 9, getBuf().size)
|
||||
}
|
||||
// This time make too much space:
|
||||
stream.alsoAsByteBuffer(11) {
|
||||
it.put("0123456789".toByteArray())
|
||||
}
|
||||
stream.write("def".toByteArray())
|
||||
assertArrayEquals("abc0123456789def".toByteArray(), stream.toByteArray())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `ByteBufferOutputStream discards data after final position`() {
|
||||
val stream = ByteBufferOutputStream(0)
|
||||
stream.alsoAsByteBuffer(10) {
|
||||
it.put("0123456789".toByteArray())
|
||||
it.position(5)
|
||||
}
|
||||
stream.write("def".toByteArray())
|
||||
assertArrayEquals("01234def".toByteArray(), stream.toByteArray())
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user