diff --git a/.ci/api-current.txt b/.ci/api-current.txt index af70ca2c78..dee3db0766 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -2817,7 +2817,7 @@ public final class net.corda.core.serialization.ObjectWithCompatibleContext exte public final class net.corda.core.serialization.SerializationAPIKt extends java.lang.Object @org.jetbrains.annotations.NotNull public static final net.corda.core.serialization.SerializedBytes serialize(Object, net.corda.core.serialization.SerializationFactory, net.corda.core.serialization.SerializationContext) ## -public interface net.corda.core.serialization.SerializationContext +@net.corda.core.DoNotImplement public interface net.corda.core.serialization.SerializationContext @org.jetbrains.annotations.NotNull public abstract ClassLoader getDeserializationClassLoader() public abstract boolean getObjectReferencesEnabled() @org.jetbrains.annotations.NotNull public abstract net.corda.core.utilities.ByteSequence getPreferredSerializationVersion() diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index 82ee0d402a..b8e90fc8c2 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -1,5 +1,6 @@ package net.corda.core.serialization +import net.corda.core.DoNotImplement import net.corda.core.crypto.SecureHash import net.corda.core.crypto.sha256 import net.corda.core.serialization.internal.effectiveSerializationEnv @@ -99,14 +100,22 @@ abstract class SerializationFactory { } } typealias SerializationMagic = ByteSequence +@DoNotImplement +interface SerializationEncoding + /** * Parameters to serialization and deserialization. */ +@DoNotImplement interface SerializationContext { /** * When serializing, use the format this header sequence represents. */ val preferredSerializationVersion: SerializationMagic + /** + * If non-null, apply this encoding (typically compression) when serializing. + */ + val encoding: SerializationEncoding? /** * The class loader to use for deserialization. */ @@ -115,6 +124,10 @@ interface SerializationContext { * A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized. */ val whitelist: ClassWhitelist + /** + * A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing. + */ + val encodingWhitelist: EncodingWhitelist /** * A map of any addition properties specific to the particular use case. */ @@ -161,6 +174,11 @@ interface SerializationContext { */ fun withPreferredSerializationVersion(magic: SerializationMagic): SerializationContext + /** + * A shallow copy of this context but with the given (possibly null) encoding. + */ + fun withEncoding(encoding: SerializationEncoding?): SerializationContext + /** * The use case that we are serializing for, since it influences the implementations chosen. */ @@ -232,3 +250,8 @@ class SerializedBytes(bytes: ByteArray) : OpaqueBytes(bytes) { interface ClassWhitelist { fun hasListed(type: Class<*>): Boolean } + +@DoNotImplement +interface EncodingWhitelist { + fun acceptEncoding(encoding: SerializationEncoding): Boolean +} diff --git a/node-api/build.gradle b/node-api/build.gradle index d2c51dde94..bf674d7b32 100644 --- a/node-api/build.gradle +++ b/node-api/build.gradle @@ -34,6 +34,9 @@ dependencies { // For AMQP serialisation. compile "org.apache.qpid:proton-j:0.21.0" + // Pure-Java Snappy compression + compile 'org.iq80.snappy:snappy:0.4' + // Unit testing helpers. testCompile "junit:junit:$junit_version" testCompile "org.assertj:assertj-core:$assertj_version" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ClientContexts.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ClientContexts.kt index 3c2eb8fa76..e4e2f53417 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ClientContexts.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ClientContexts.kt @@ -18,10 +18,12 @@ val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(kryoMagic, GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), emptyMap(), true, - SerializationContext.UseCase.RPCClient) + SerializationContext.UseCase.RPCClient, + null) val AMQP_RPC_CLIENT_CONTEXT = SerializationContextImpl(amqpMagic, SerializationDefaults.javaClass.classLoader, GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), emptyMap(), true, - SerializationContext.UseCase.RPCClient) + SerializationContext.UseCase.RPCClient, + null) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/OrdinalIO.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/OrdinalIO.kt new file mode 100644 index 0000000000..6e04d490f4 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/OrdinalIO.kt @@ -0,0 +1,31 @@ +package net.corda.nodeapi.internal.serialization + +import java.io.EOFException +import java.io.InputStream +import java.io.OutputStream +import java.nio.ByteBuffer + +class OrdinalBits(private val ordinal: Int) { + interface OrdinalWriter { + val bits: OrdinalBits + val encodedSize get() = 1 + fun writeTo(stream: OutputStream) = stream.write(bits.ordinal) + fun putTo(buffer: ByteBuffer) = buffer.put(bits.ordinal.toByte())!! + } + + init { + require(ordinal >= 0) { "The ordinal must be non-negative." } + require(ordinal < 128) { "Consider implementing a varint encoding." } + } +} + +class OrdinalReader(private val values: Array) { + private val enumName = values[0].javaClass.simpleName + private val range = 0 until values.size + fun readFrom(stream: InputStream): E { + val ordinal = stream.read() + if (ordinal == -1) throw EOFException("Expected a $enumName ordinal.") + if (ordinal !in range) throw NoSuchElementException("No $enumName with ordinal: $ordinal") + return values[ordinal] + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationFormat.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationFormat.kt index 6414efbb17..fefdfb930f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationFormat.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationFormat.kt @@ -1,8 +1,17 @@ package net.corda.nodeapi.internal.serialization +import net.corda.core.internal.VisibleForTesting +import net.corda.core.serialization.SerializationEncoding import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.OpaqueBytes +import net.corda.nodeapi.internal.serialization.OrdinalBits.OrdinalWriter +import org.iq80.snappy.SnappyFramedInputStream +import org.iq80.snappy.SnappyFramedOutputStream +import java.io.OutputStream +import java.io.InputStream import java.nio.ByteBuffer +import java.util.zip.DeflaterOutputStream +import java.util.zip.InflaterInputStream class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) { private val bufferView = slice() @@ -10,3 +19,40 @@ class CordaSerializationMagic(bytes: ByteArray) : OpaqueBytes(bytes) { return if (data.slice(end = size) == bufferView) data.slice(size) else null } } + +enum class SectionId : OrdinalWriter { + /** Serialization data follows, and then discard the rest of the stream (if any) as legacy data may have trailing garbage. */ + DATA_AND_STOP, + /** Identical behaviour to [DATA_AND_STOP], historically used for Kryo. Do not use in new code. */ + ALT_DATA_AND_STOP, + /** The ordinal of a [CordaSerializationEncoding] follows, which should be used to decode the remainder of the stream. */ + ENCODING; + + companion object { + val reader = OrdinalReader(values()) + } + + override val bits = OrdinalBits(ordinal) +} + +enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter { + DEFLATE { + override fun wrap(stream: OutputStream) = DeflaterOutputStream(stream) + override fun wrap(stream: InputStream) = InflaterInputStream(stream) + }, + SNAPPY { + override fun wrap(stream: OutputStream) = SnappyFramedOutputStream(stream) + override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false) + }; + + companion object { + val reader = OrdinalReader(values()) + } + + override val bits = OrdinalBits(ordinal) + abstract fun wrap(stream: OutputStream): OutputStream + abstract fun wrap(stream: InputStream): InputStream +} + +@VisibleForTesting +internal val encodingNotPermittedFormat = "Encoding not permitted: %s" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt index 95dcc9b603..a093bc871b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt @@ -18,13 +18,18 @@ import java.util.concurrent.ExecutionException val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled" -data class SerializationContextImpl(override val preferredSerializationVersion: SerializationMagic, - override val deserializationClassLoader: ClassLoader, - override val whitelist: ClassWhitelist, - override val properties: Map, - override val objectReferencesEnabled: Boolean, - override val useCase: SerializationContext.UseCase) : SerializationContext { +internal object NullEncodingWhitelist : EncodingWhitelist { + override fun acceptEncoding(encoding: SerializationEncoding) = false +} +data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic, + override val deserializationClassLoader: ClassLoader, + override val whitelist: ClassWhitelist, + override val properties: Map, + override val objectReferencesEnabled: Boolean, + override val useCase: SerializationContext.UseCase, + override val encoding: SerializationEncoding?, + override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : SerializationContext { private val cache: Cache, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build() /** @@ -70,6 +75,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion: } override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic) + override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding) } open class SerializationFactoryImpl : SerializationFactory() { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ServerContexts.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ServerContexts.kt index a2bfa64628..cc8dcfa305 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ServerContexts.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/ServerContexts.kt @@ -27,22 +27,26 @@ val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(kryoMagic, GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), emptyMap(), true, - SerializationContext.UseCase.RPCServer) + SerializationContext.UseCase.RPCServer, + null) val KRYO_STORAGE_CONTEXT = SerializationContextImpl(kryoMagic, SerializationDefaults.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, - SerializationContext.UseCase.Storage) + SerializationContext.UseCase.Storage, + null) val AMQP_STORAGE_CONTEXT = SerializationContextImpl(amqpMagic, SerializationDefaults.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, - SerializationContext.UseCase.Storage) + SerializationContext.UseCase.Storage, + null) val AMQP_RPC_SERVER_CONTEXT = SerializationContextImpl(amqpMagic, SerializationDefaults.javaClass.classLoader, GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), emptyMap(), true, - SerializationContext.UseCase.RPCServer) + SerializationContext.UseCase.RPCServer, + null) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SharedContexts.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SharedContexts.kt index 25e4e278a1..9620b3c999 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SharedContexts.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SharedContexts.kt @@ -20,18 +20,19 @@ val KRYO_P2P_CONTEXT = SerializationContextImpl(kryoMagic, GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), emptyMap(), true, - SerializationContext.UseCase.P2P) + SerializationContext.UseCase.P2P, + null) val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(kryoMagic, SerializationDefaults.javaClass.classLoader, QuasarWhitelist, emptyMap(), true, - SerializationContext.UseCase.Checkpoint) + SerializationContext.UseCase.Checkpoint, + null) val AMQP_P2P_CONTEXT = SerializationContextImpl(amqpMagic, SerializationDefaults.javaClass.classLoader, GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), emptyMap(), true, - SerializationContext.UseCase.P2P) - - + SerializationContext.UseCase.P2P, + null) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt new file mode 100644 index 0000000000..f45ac6d864 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPStreams.kt @@ -0,0 +1,31 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import com.esotericsoftware.kryo.io.ByteBufferInputStream +import net.corda.nodeapi.internal.serialization.kryo.ByteBufferOutputStream +import net.corda.nodeapi.internal.serialization.kryo.serializeOutputStreamPool +import java.io.InputStream +import java.io.OutputStream +import java.nio.ByteBuffer + +fun InputStream.asByteBuffer(): ByteBuffer { + return if (this is ByteBufferInputStream) { + byteBuffer // BBIS has no other state, so this is perfectly safe. + } else { + ByteBuffer.wrap(serializeOutputStreamPool.run { + copyTo(it) + it.toByteArray() + }) + } +} + +fun OutputStream.alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T { + return if (this is ByteBufferOutputStream) { + alsoAsByteBuffer(remaining, task) + } else { + serializeOutputStreamPool.run { + val result = it.alsoAsByteBuffer(remaining, task) + it.copyTo(this) + result + } + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt index da555798e1..cdde047ef8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt @@ -1,18 +1,27 @@ package net.corda.nodeapi.internal.serialization.amqp +import com.esotericsoftware.kryo.io.ByteBufferInputStream +import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.getStackTraceAsString +import net.corda.core.serialization.EncodingWhitelist import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding +import net.corda.nodeapi.internal.serialization.NullEncodingWhitelist +import net.corda.nodeapi.internal.serialization.SectionId +import net.corda.nodeapi.internal.serialization.encodingNotPermittedFormat import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.DescribedType import org.apache.qpid.proton.amqp.UnsignedByte import org.apache.qpid.proton.amqp.UnsignedInteger import org.apache.qpid.proton.codec.Data +import java.io.InputStream import java.io.NotSerializableException import java.lang.reflect.ParameterizedType import java.lang.reflect.Type import java.lang.reflect.TypeVariable import java.lang.reflect.WildcardType +import java.nio.ByteBuffer data class ObjectAndEnvelope(val obj: T, val envelope: Envelope) @@ -22,7 +31,8 @@ data class ObjectAndEnvelope(val obj: T, val envelope: Envelope) * @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple * instances and threads. */ -class DeserializationInput(internal val serializerFactory: SerializerFactory) { +class DeserializationInput @JvmOverloads constructor(private val serializerFactory: SerializerFactory, + private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) { private val objectHistory: MutableList = mutableListOf() internal companion object { @@ -47,6 +57,28 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { } return size + BYTES_NEEDED_TO_PEEK } + + @VisibleForTesting + @Throws(NotSerializableException::class) + internal fun withDataBytes(byteSequence: ByteSequence, encodingWhitelist: EncodingWhitelist, task: (ByteBuffer) -> T): T { + // Check that the lead bytes match expected header + val amqpSequence = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.") + var stream: InputStream = ByteBufferInputStream(amqpSequence) + try { + while (true) { + when (SectionId.reader.readFrom(stream)) { + SectionId.ENCODING -> { + val encoding = CordaSerializationEncoding.reader.readFrom(stream) + encodingWhitelist.acceptEncoding(encoding) || throw NotSerializableException(encodingNotPermittedFormat.format(encoding)) + stream = encoding.wrap(stream) + } + SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> return task(stream.asByteBuffer()) + } + } + } finally { + stream.close() + } + } } @Throws(NotSerializableException::class) @@ -58,12 +90,12 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { @Throws(NotSerializableException::class) internal fun getEnvelope(byteSequence: ByteSequence): Envelope { - // Check that the lead bytes match expected header - val dataBytes = amqpMagic.consume(byteSequence) ?: throw NotSerializableException("Serialization header does not match.") - val data = Data.Factory.create() - val expectedSize = dataBytes.remaining() - if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data") - return Envelope.get(data) + return withDataBytes(byteSequence, encodingWhitelist) { dataBytes -> + val data = Data.Factory.create() + val expectedSize = dataBytes.remaining() + if (data.decode(dataBytes) != expectedSize.toLong()) throw NotSerializableException("Unexpected size of data") + Envelope.get(data) + } } @Throws(NotSerializableException::class) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt index 285a1f5d54..1318d066fe 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt @@ -12,7 +12,7 @@ import net.corda.nodeapi.internal.serialization.carpenter.Field as CarpenterFiel import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema const val DESCRIPTOR_DOMAIN: String = "net.corda" -val amqpMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(1, 0, 0)) +val amqpMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(1, 0)) /** * This and the classes below are OO representations of the AMQP XML schema described in the specification. Their diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt index f6b9972ec7..1dcf750ef5 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt @@ -1,10 +1,14 @@ package net.corda.nodeapi.internal.serialization.amqp +import net.corda.core.serialization.SerializationEncoding import net.corda.core.serialization.SerializedBytes +import net.corda.nodeapi.internal.serialization.CordaSerializationEncoding +import net.corda.nodeapi.internal.serialization.SectionId +import net.corda.nodeapi.internal.serialization.kryo.byteArrayOutput import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException +import java.io.OutputStream import java.lang.reflect.Type -import java.nio.ByteBuffer import java.util.* import kotlin.collections.LinkedHashSet @@ -19,8 +23,7 @@ data class BytesAndSchemas( * @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple * instances and threads. */ -open class SerializationOutput(internal val serializerFactory: SerializerFactory) { - +open class SerializationOutput @JvmOverloads constructor(internal val serializerFactory: SerializerFactory, private val encoding: SerializationEncoding? = null) { private val objectHistory: MutableMap = IdentityHashMap() private val serializerHistory: MutableSet> = LinkedHashSet() internal val schemaHistory: MutableSet = LinkedHashSet() @@ -67,11 +70,21 @@ open class SerializationOutput(internal val serializerFactory: SerializerFactory writeTransformSchema(TransformsSchema.build(schema, serializerFactory), this) } } - val bytes = ByteArray(data.encodedSize().toInt() + 8) - val buf = ByteBuffer.wrap(bytes) - amqpMagic.putTo(buf) - data.encode(buf) - return SerializedBytes(bytes) + return SerializedBytes(byteArrayOutput { + var stream: OutputStream = it + try { + amqpMagic.writeTo(stream) + if (encoding != null) { + SectionId.ENCODING.writeTo(stream) + (encoding as CordaSerializationEncoding).writeTo(stream) + stream = encoding.wrap(stream) + } + SectionId.DATA_AND_STOP.writeTo(stream) + stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode) + } finally { + stream.close() + } + }) } internal fun writeObject(obj: Any, data: Data) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt index 719d982db3..7e1b94fffd 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoSerializationScheme.kt @@ -16,10 +16,12 @@ import net.corda.core.utilities.ByteSequence import net.corda.core.serialization.* import net.corda.nodeapi.internal.serialization.CordaSerializationMagic import net.corda.nodeapi.internal.serialization.CordaClassResolver +import net.corda.nodeapi.internal.serialization.SectionId import net.corda.nodeapi.internal.serialization.SerializationScheme +import net.corda.nodeapi.internal.serialization.* import java.security.PublicKey -val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0, 1)) +val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0)) private object AutoCloseableSerialisationDetector : Serializer() { override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) { @@ -87,11 +89,25 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { val dataBytes = kryoMagic.consume(byteSequence) ?: throw KryoException("Serialized bytes header does not match expected format.") return context.kryo { kryoInput(ByteBufferInputStream(dataBytes)) { - if (context.objectReferencesEnabled) { - uncheckedCast(readClassAndObject(this)) - } else { - withoutReferences { uncheckedCast(readClassAndObject(this)) } + val result: T + loop@ while (true) { + when (SectionId.reader.readFrom(this)) { + SectionId.ENCODING -> { + val encoding = CordaSerializationEncoding.reader.readFrom(this) + context.encodingWhitelist.acceptEncoding(encoding) || throw KryoException(encodingNotPermittedFormat.format(encoding)) + substitute(encoding::wrap) + } + SectionId.DATA_AND_STOP, SectionId.ALT_DATA_AND_STOP -> { + result = if (context.objectReferencesEnabled) { + uncheckedCast(readClassAndObject(this)) + } else { + withoutReferences { uncheckedCast(readClassAndObject(this)) } + } + break@loop + } + } } + result } } } @@ -100,6 +116,12 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { return context.kryo { SerializedBytes(kryoOutput { kryoMagic.writeTo(this) + context.encoding?.let { encoding -> + SectionId.ENCODING.writeTo(this) + (encoding as CordaSerializationEncoding).writeTo(this) + substitute(encoding::wrap) + } + SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case. if (context.objectReferencesEnabled) { writeClassAndObject(this, obj) } else { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt index 9a34131a30..b1274223cc 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreams.kt @@ -4,13 +4,34 @@ import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import net.corda.core.internal.LazyPool import java.io.* +import java.nio.ByteBuffer + +class ByteBufferOutputStream(size: Int) : ByteArrayOutputStream(size) { + companion object { + private val ensureCapacity = ByteArrayOutputStream::class.java.getDeclaredMethod("ensureCapacity", Int::class.java).apply { + isAccessible = true + } + } + + fun alsoAsByteBuffer(remaining: Int, task: (ByteBuffer) -> T): T { + ensureCapacity.invoke(this, count + remaining) + val buffer = ByteBuffer.wrap(buf, count, remaining) + val result = task(buffer) + count = buffer.position() + return result + } + + fun copyTo(stream: OutputStream) { + stream.write(buf, 0, count) + } +} private val serializationBufferPool = LazyPool( newInstance = { ByteArray(64 * 1024) }) -private val serializeOutputStreamPool = LazyPool( - clear = ByteArrayOutputStream::reset, +internal val serializeOutputStreamPool = LazyPool( + clear = ByteBufferOutputStream::reset, shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large - newInstance = { ByteArrayOutputStream(64 * 1024) }) + newInstance = { ByteBufferOutputStream(64 * 1024) }) internal fun kryoInput(underlying: InputStream, task: Input.() -> T): T { return serializationBufferPool.run { @@ -22,13 +43,19 @@ internal fun kryoInput(underlying: InputStream, task: Input.() -> T): T { } internal fun kryoOutput(task: Output.() -> T): ByteArray { - return serializeOutputStreamPool.run { underlying -> + return byteArrayOutput { underlying -> serializationBufferPool.run { Output(it).use { output -> output.outputStream = underlying output.task() } } + } +} + +internal fun byteArrayOutput(task: (ByteBufferOutputStream) -> T): ByteArray { + return serializeOutputStreamPool.run { underlying -> + task(underlying) underlying.toByteArray() // Must happen after close, to allow ZIP footer to be written for example. } } diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java index 265cbf098b..123bf60e0f 100644 --- a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/ForbiddenLambdaSerializationTests.java @@ -33,7 +33,7 @@ public final class ForbiddenLambdaSerializationTests { EnumSet contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint)); contexts.forEach(ctx -> { - SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx); + SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx, null); String value = "Hey"; Callable target = (Callable & Serializable) () -> value; @@ -55,7 +55,7 @@ public final class ForbiddenLambdaSerializationTests { EnumSet contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint)); contexts.forEach(ctx -> { - SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx); + SerializationContext context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx, null); String value = "Hey"; Callable target = () -> value; diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java index 0203c498f4..6482240ba9 100644 --- a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/LambdaCheckpointSerializationTest.java @@ -26,7 +26,7 @@ public final class LambdaCheckpointSerializationTest { @Before public void setup() { factory = testSerialization.getSerializationFactory(); - context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint); + context = new SerializationContextImpl(KryoSerializationSchemeKt.getKryoMagic(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint, null); } @Test diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt index 2f3e7420fd..8e5aa1c776 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt @@ -319,7 +319,8 @@ class X509UtilitiesTest { AllWhitelist, emptyMap(), true, - SerializationContext.UseCase.P2P) + SerializationContext.UseCase.P2P, + null) val expected = X509Utilities.createSelfSignedCACertificate(ALICE.name.x500Principal, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)) val serialized = expected.serialize(factory, context).bytes val actual = serialized.deserialize(factory, context) @@ -334,7 +335,8 @@ class X509UtilitiesTest { AllWhitelist, emptyMap(), true, - SerializationContext.UseCase.P2P) + SerializationContext.UseCase.P2P, + null) val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE_NAME.x500Principal, rootCAKey) val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB_NAME.x500Principal, BOB.publicKey) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt index a66f3038b4..87157fdaa9 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt @@ -108,8 +108,8 @@ class CordaClassResolverTests { val emptyMapClass = mapOf().javaClass } - private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P) - private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P) + private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null) + private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null) @Test fun `Annotation on enum works for specialised entries`() { CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt index 5d54a38fbe..7150e8c566 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt @@ -1,10 +1,13 @@ package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import com.google.common.primitives.Ints +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever import net.corda.core.contracts.PrivacySalt import net.corda.core.crypto.* import net.corda.core.internal.FetchDataFlow @@ -16,24 +19,29 @@ import net.corda.node.services.persistence.NodeAttachmentService import net.corda.nodeapi.internal.serialization.kryo.kryoMagic import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.TestIdentity -import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatThrownBy +import net.corda.testing.internal.rigorousMock +import org.assertj.core.api.Assertions.* import org.junit.Assert.assertArrayEquals +import org.junit.Assert.assertEquals import org.junit.Before import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters import org.slf4j.LoggerFactory import java.io.ByteArrayInputStream import java.io.InputStream import java.time.Instant import java.util.* -import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertNotNull -import kotlin.test.assertTrue +import kotlin.test.* -class KryoTests { +@RunWith(Parameterized::class) +class KryoTests(private val compression: CordaSerializationEncoding?) { companion object { private val ALICE_PUBKEY = TestIdentity(ALICE_NAME, 70).publicKey + @Parameters(name = "{0}") + @JvmStatic + fun compression() = arrayOf(null) + CordaSerializationEncoding.values() } private lateinit var factory: SerializationFactory @@ -47,7 +55,11 @@ class KryoTests { AllWhitelist, emptyMap(), true, - SerializationContext.UseCase.Storage) + SerializationContext.UseCase.Storage, + compression, + rigorousMock().also { + if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression) + }) } @Test @@ -259,7 +271,8 @@ class KryoTests { AllWhitelist, emptyMap(), true, - SerializationContext.UseCase.P2P) + SerializationContext.UseCase.P2P, + null) pt.serialize(factory, context) } @@ -300,4 +313,24 @@ class KryoTests { val exception2 = exception.serialize(factory, context).deserialize(factory, context) assertEquals(randomHash, exception2.requested) } + + @Test + fun `compression has the desired effect`() { + compression ?: return + val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } + val compressed = data.serialize(factory, context) + assertEquals(.5, compressed.size.toDouble() / data.size, .03) + assertArrayEquals(data, compressed.deserialize(factory, context)) + } + + @Test + fun `a particular encoding can be banned for deserialization`() { + compression ?: return + doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression) + val compressed = "whatever".serialize(factory, context) + catchThrowable { compressed.deserialize(factory, context) }.run { + assertSame(KryoException::class.java, javaClass) + assertEquals(encodingNotPermittedFormat.format(compression), message) + } + } } \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt index dc53b8fc29..7e1ffac95a 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt @@ -69,6 +69,7 @@ class ListsSerializationTest { val serializedForm = emptyList().serialize() val output = ByteArrayOutputStream().apply { kryoMagic.writeTo(this) + SectionId.ALT_DATA_AND_STOP.writeTo(this) write(DefaultClassResolver.NAME + 2) write(nameID) write(javaEmptyListClass.name.toAscii()) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt index a76bb8a52e..8efb66fffd 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt @@ -79,6 +79,7 @@ class MapsSerializationTest { val serializedForm = emptyMap().serialize() val output = ByteArrayOutputStream().apply { kryoMagic.writeTo(this) + SectionId.ALT_DATA_AND_STOP.writeTo(this) write(DefaultClassResolver.NAME + 2) write(nameID) write(javaEmptyMapClass.name.toAscii()) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt index 17495cb360..06a0f86d35 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt @@ -99,6 +99,7 @@ class SerializationTokenTest { val stream = ByteArrayOutputStream() Output(stream).use { kryoMagic.writeTo(it) + SectionId.ALT_DATA_AND_STOP.writeTo(it) kryo.writeClass(it, SingletonSerializeAsToken::class.java) kryo.writeObject(it, emptyList()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt index edd1eabf58..7d4a352323 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt @@ -56,6 +56,7 @@ class SetsSerializationTest { val serializedForm = emptySet().serialize() val output = ByteArrayOutputStream().apply { kryoMagic.writeTo(this) + SectionId.ALT_DATA_AND_STOP.writeTo(this) write(DefaultClassResolver.NAME + 2) write(nameID) write(javaEmptySetClass.name.toAscii()) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt index 89b2dcbad4..6d20000170 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt @@ -2,6 +2,8 @@ package net.corda.nodeapi.internal.serialization.amqp +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever import net.corda.client.rpc.RPCException import net.corda.core.CordaRuntimeException import net.corda.core.contracts.* @@ -11,21 +13,16 @@ import net.corda.core.flows.FlowException import net.corda.core.identity.AbstractParty import net.corda.core.identity.CordaX500Name import net.corda.core.internal.AbstractAttachment -import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.MissingAttachmentsException -import net.corda.core.serialization.SerializationFactory +import net.corda.core.serialization.* import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.OpaqueBytes -import net.corda.nodeapi.internal.serialization.AllWhitelist -import net.corda.nodeapi.internal.serialization.EmptyWhitelist -import net.corda.nodeapi.internal.serialization.GeneratedAttachment +import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive -import net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer -import net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer import net.corda.testing.contracts.DummyContract import net.corda.testing.core.BOB_NAME import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.TestIdentity +import net.corda.testing.internal.rigorousMock import org.apache.activemq.artemis.api.core.SimpleString import org.apache.qpid.proton.amqp.* import org.apache.qpid.proton.codec.DecoderImpl @@ -35,22 +32,23 @@ import org.junit.Assert.* import org.junit.Ignore import org.junit.Rule import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters import java.io.ByteArrayInputStream import java.io.IOException import java.io.NotSerializableException -import java.lang.reflect.Type import java.math.BigDecimal -import java.nio.ByteBuffer import java.time.* import java.time.temporal.ChronoUnit import java.util.* -import java.util.concurrent.ConcurrentHashMap import kotlin.reflect.full.superclasses import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class SerializationOutputTests { +@RunWith(Parameterized::class) +class SerializationOutputTests(private val compression: CordaSerializationEncoding?) { private companion object { val BOB_IDENTITY = TestIdentity(BOB_NAME, 80).identity val megaCorp = TestIdentity(CordaX500Name("MegaCorp", "London", "GB")) @@ -59,6 +57,9 @@ class SerializationOutputTests { val MEGA_CORP_PUBKEY get() = megaCorp.publicKey val MINI_CORP get() = miniCorp.party val MINI_CORP_PUBKEY get() = miniCorp.publicKey + @Parameters(name = "{0}") + @JvmStatic + fun compression() = arrayOf(null) + CordaSerializationEncoding.values() } @Rule @@ -173,16 +174,20 @@ class SerializationOutputTests { } } + private val encodingWhitelist = rigorousMock().also { + if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression) + } + + private fun defaultFactory() = SerializerFactory( + AllWhitelist, ClassLoader.getSystemClassLoader(), + EvolutionSerializerGetterTesting()) + private inline fun serdes(obj: T, - factory: SerializerFactory = SerializerFactory( - AllWhitelist, ClassLoader.getSystemClassLoader(), - EvolutionSerializerGetterTesting()), - freshDeserializationFactory: SerializerFactory = SerializerFactory( - AllWhitelist, ClassLoader.getSystemClassLoader(), - EvolutionSerializerGetterTesting()), + factory: SerializerFactory = defaultFactory(), + freshDeserializationFactory: SerializerFactory = defaultFactory(), expectedEqual: Boolean = true, expectDeserializedEqual: Boolean = true): T { - val ser = SerializationOutput(factory) + val ser = SerializationOutput(factory, compression) val bytes = ser.serialize(obj) val decoder = DecoderImpl().apply { @@ -198,18 +203,19 @@ class SerializationOutputTests { this.register(TransformTypes.DESCRIPTOR, TransformTypes.Companion) } EncoderImpl(decoder) - decoder.setByteBuffer(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8)) - // Check that a vanilla AMQP decoder can deserialize without schema. - val result = decoder.readObject() as Envelope - assertNotNull(result) - - val des = DeserializationInput(freshDeserializationFactory) + DeserializationInput.withDataBytes(bytes, encodingWhitelist) { + decoder.setByteBuffer(it) + // Check that a vanilla AMQP decoder can deserialize without schema. + val result = decoder.readObject() as Envelope + assertNotNull(result) + } + val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist) val desObj = des.deserialize(bytes) assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual) // Now repeat with a re-used factory - val ser2 = SerializationOutput(factory) - val des2 = DeserializationInput(factory) + val ser2 = SerializationOutput(factory, compression) + val des2 = DeserializationInput(factory, encodingWhitelist) val desObj2 = des2.deserialize(ser2.serialize(obj)) assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual) assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual) @@ -432,9 +438,9 @@ class SerializationOutputTests { @Test fun `class constructor is invoked on deserialisation`() { - val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())) - val des = DeserializationInput(ser.serializerFactory) - + compression == null || return // Manipulation of serialized bytes is invalid if they're compressed. + val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression) + val des = DeserializationInput(ser.serializerFactory, encodingWhitelist) val serialisedOne = ser.serialize(NonZeroByte(1)).bytes val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes @@ -1116,6 +1122,29 @@ class SerializationOutputTests { val c = C(Amount(100, BigDecimal("1.5"), Currency.getInstance("USD"))) // were the issue not fixed we'd blow up here - SerializationOutput(factory).serialize(c) + SerializationOutput(factory, compression).serialize(c) + } + + @Test + fun `compression has the desired effect`() { + compression ?: return + val factory = defaultFactory() + val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } + val compressed = SerializationOutput(factory, compression).serialize(data) + assertEquals(.5, compressed.size.toDouble() / data.size, .03) + assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed)) + } + + @Test + fun `a particular encoding can be banned for deserialization`() { + compression ?: return + val factory = defaultFactory() + doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression) + val compressed = SerializationOutput(factory, compression).serialize("whatever") + val input = DeserializationInput(factory, encodingWhitelist) + catchThrowable { input.deserialize(compressed) }.run { + assertSame(NotSerializableException::class.java, javaClass) + assertEquals(encodingNotPermittedFormat.format(compression), message) + } } } \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt index ac9779f828..d8eedd305d 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoStreamsTest.kt @@ -1,12 +1,16 @@ package net.corda.nodeapi.internal.serialization.kryo +import net.corda.core.internal.declaredField +import org.assertj.core.api.Assertions.catchThrowable import org.junit.Assert.assertArrayEquals import org.junit.Test import java.io.* +import java.nio.BufferOverflowException import java.util.* import java.util.zip.DeflaterOutputStream import java.util.zip.InflaterInputStream import kotlin.test.assertEquals +import kotlin.test.assertSame class KryoStreamsTest { class NegOutputStream(private val stream: OutputStream) : OutputStream() { @@ -57,4 +61,37 @@ class KryoStreamsTest { assertEquals(-1, read()) } } + + @Test + fun `ByteBufferOutputStream works`() { + val stream = ByteBufferOutputStream(3) + stream.write("abc".toByteArray()) + val getBuf = stream.declaredField(ByteArrayOutputStream::class, "buf")::value + assertEquals(3, getBuf().size) + repeat(2) { + assertSame(BufferOverflowException::class.java, catchThrowable { + stream.alsoAsByteBuffer(9) { + it.put("0123456789".toByteArray()) + } + }.javaClass) + assertEquals(3 + 9, getBuf().size) + } + // This time make too much space: + stream.alsoAsByteBuffer(11) { + it.put("0123456789".toByteArray()) + } + stream.write("def".toByteArray()) + assertArrayEquals("abc0123456789def".toByteArray(), stream.toByteArray()) + } + + @Test + fun `ByteBufferOutputStream discards data after final position`() { + val stream = ByteBufferOutputStream(0) + stream.alsoAsByteBuffer(10) { + it.put("0123456789".toByteArray()) + it.position(5) + } + stream.write("def".toByteArray()) + assertArrayEquals("01234def".toByteArray(), stream.toByteArray()) + } }