ENT-2439 Fix compression in serialization (#3825)

* ENT-2439 Fix compression in serialization
This commit is contained in:
Rick Parker 2018-08-22 10:37:18 +01:00 committed by GitHub
parent 96d645c316
commit 1d05c16942
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 83 additions and 38 deletions

View File

@ -198,6 +198,11 @@ interface SerializationContext {
*/
fun withEncoding(encoding: SerializationEncoding?): SerializationContext
/**
* A shallow copy of this context but with the given encoding whitelist.
*/
fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): SerializationContext
/**
* The use case that we are serializing for, since it influences the implementations chosen.
*/

View File

@ -44,7 +44,6 @@ class TestScheme : AbstractKryoSerializationScheme() {
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}
@RunWith(Parameterized::class)
@ -89,7 +88,6 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null))
}
@Test
fun `serialised form is stable when the same object instance is added to the deserialised object graph`() {
val noReferencesContext = context.withoutReferences()
@ -356,4 +354,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
}
// Serialises a holder of a 20000-byte array twice using the test's factory and context: once with the encoding
// set to null and once with SNAPPY, then pins the exact size of each serialised form.
@Test
fun `compression reduces number of bytes significantly`() {
class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size
val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, it sounds like the Kryo wire format changed and checkpoints might not survive an upgrade.
assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize)
}
}

View File

@ -7,6 +7,7 @@ import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.OrdinalBits.OrdinalWriter
import org.iq80.snappy.SnappyFramedInputStream
import org.iq80.snappy.SnappyFramedOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
@ -44,7 +45,7 @@ enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
override fun wrap(stream: InputStream) = InflaterInputStream(stream)
},
SNAPPY {
override fun wrap(stream: OutputStream) = SnappyFramedOutputStream(stream)
override fun wrap(stream: OutputStream) = FlushAverseOutputStream(SnappyFramedOutputStream(stream))
override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false)
};
@ -58,3 +59,21 @@ enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
}
const val encodingNotPermittedFormat = "Encoding not permitted: %s"
/**
 * An [OutputStream] wrapper whose [flush] deliberately does nothing. This is because Kryo keeps calling flush all the
 * time, which stops the Snappy stream from building up big chunks to compress and instead keeps compressing small
 * chunks, giving a terrible compression ratio.
 *
 * Writes are forwarded to [delegate] unchanged. The only point at which the delegate is flushed is [close].
 */
class FlushAverseOutputStream(private val delegate: OutputStream) : OutputStream() {
    /** Forwards a single byte to the delegate. */
    @Throws(IOException::class)
    override fun write(b: Int) = delegate.write(b)

    /** Forwards `len` bytes of [b], starting at [off], to the delegate in a single call. */
    @Throws(IOException::class)
    override fun write(b: ByteArray?, off: Int, len: Int) = delegate.write(b, off, len)

    /** Intentionally a no-op: intermediate flushes are swallowed and never reach the delegate. */
    @Throws(IOException::class)
    override fun flush() {
        // Deliberately empty; see the class documentation.
    }

    /**
     * Flushes the delegate once and then closes it.
     *
     * The delegate is closed even if the flush throws, so a failing flush cannot leak the underlying stream.
     */
    @Throws(IOException::class)
    override fun close() {
        try {
            delegate.flush()
        } finally {
            delegate.close()
        }
    }
}

View File

@ -67,6 +67,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic)
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist)
}
/*

View File

@ -29,9 +29,8 @@ data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
* instances and threads.
*/
@KeepForDJVM
class DeserializationInput @JvmOverloads constructor(
private val serializerFactory: SerializerFactory,
private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist
class DeserializationInput constructor(
private val serializerFactory: SerializerFactory
) {
private val objectHistory: MutableList<Any> = mutableListOf()
private val logger = loggerFor<DeserializationInput>()
@ -80,9 +79,9 @@ class DeserializationInput @JvmOverloads constructor(
}
}
@VisibleForTesting
@Throws(AMQPNoTypeNotSerializableException::class)
fun getEnvelope(byteSequence: ByteSequence) = getEnvelope(byteSequence, encodingWhitelist)
fun getEnvelope(byteSequence: ByteSequence, context: SerializationContext) = getEnvelope(byteSequence, context.encodingWhitelist)
@Throws(
AMQPNotSerializableException::class,
@ -116,7 +115,7 @@ class DeserializationInput @JvmOverloads constructor(
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationContext): T =
des {
val envelope = getEnvelope(bytes, encodingWhitelist)
val envelope = getEnvelope(bytes, context.encodingWhitelist)
logger.trace("deserialize blob scheme=\"${envelope.schema.toString()}\"")
@ -130,7 +129,7 @@ class DeserializationInput @JvmOverloads constructor(
clazz: Class<T>,
context: SerializationContext
): ObjectAndEnvelope<T> = des {
val envelope = getEnvelope(bytes, encodingWhitelist)
val envelope = getEnvelope(bytes, context.encodingWhitelist)
// Now pick out the obj and schema from the envelope.
ObjectAndEnvelope(
clazz.cast(readObjectOrNull(

View File

@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.contextLogger
import net.corda.serialization.internal.CordaSerializationEncoding
@ -28,9 +27,8 @@ data class BytesAndSchemas<T : Any>(
* instances and threads.
*/
@KeepForDJVM
open class SerializationOutput @JvmOverloads constructor(
internal val serializerFactory: SerializerFactory,
private val encoding: SerializationEncoding? = null
open class SerializationOutput constructor(
internal val serializerFactory: SerializerFactory
) {
companion object {
private val logger = contextLogger()
@ -90,6 +88,7 @@ open class SerializationOutput @JvmOverloads constructor(
var stream: OutputStream = it
try {
amqpMagic.writeTo(stream)
val encoding = context.encoding
if (encoding != null) {
SectionId.ENCODING.writeTo(stream)
(encoding as CordaSerializationEncoding).writeTo(stream)

View File

@ -8,9 +8,9 @@ import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.Envelope
import net.corda.serialization.internal.amqp.SerializerFactory
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.amqpSpecific
import net.corda.testing.internal.kryoSpecific
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
@ -28,7 +28,7 @@ class ListsSerializationTest {
fun <T : Any> verifyEnvelope(serBytes: SerializedBytes<T>, envVerBody: (Envelope) -> Unit) =
amqpSpecific("AMQP specific envelope verification") {
val context = SerializationFactory.defaultFactory.defaultContext
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes)
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes, context)
envVerBody(envelope)
}
}

View File

@ -219,8 +219,8 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
freshDeserializationFactory: SerializerFactory = defaultFactory(),
expectedEqual: Boolean = true,
expectDeserializedEqual: Boolean = true): T {
val ser = SerializationOutput(factory, compression)
val bytes = ser.serialize(obj)
val ser = SerializationOutput(factory)
val bytes = ser.serialize(obj, compression)
val decoder = DecoderImpl().apply {
this.register(Envelope.DESCRIPTOR, Envelope.Companion)
@ -241,14 +241,14 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
val result = decoder.readObject() as Envelope
assertNotNull(result)
}
val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist)
val desObj = des.deserialize(bytes)
val des = DeserializationInput(freshDeserializationFactory)
val desObj = des.deserialize(bytes, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
// Now repeat with a re-used factory
val ser2 = SerializationOutput(factory, compression)
val des2 = DeserializationInput(factory, encodingWhitelist)
val desObj2 = des2.deserialize(ser2.serialize(obj))
val ser2 = SerializationOutput(factory)
val des2 = DeserializationInput(factory)
val desObj2 = des2.deserialize(ser2.serialize(obj, compression), testSerializationContext.withEncodingWhitelist(encodingWhitelist))
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
@ -471,10 +471,10 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
@Test
fun `class constructor is invoked on deserialisation`() {
compression == null || return // Manipulation of serialized bytes is invalid if they're compressed.
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression)
val des = DeserializationInput(ser.serializerFactory, encodingWhitelist)
val serialisedOne = ser.serialize(NonZeroByte(1)).bytes
val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
val des = DeserializationInput(ser.serializerFactory)
val serialisedOne = ser.serialize(NonZeroByte(1), compression).bytes
val serialisedTwo = ser.serialize(NonZeroByte(2), compression).bytes
// Find the index that holds the value byte
val valueIndex = serialisedOne.zip(serialisedTwo).mapIndexedNotNull { index, (oneByte, twoByte) ->
@ -485,12 +485,12 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
// Double check
copy[valueIndex] = 0x03
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext).value).isEqualTo(3)
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist)).value).isEqualTo(3)
// Now use the forbidden value
copy[valueIndex] = 0x00
assertThatExceptionOfType(NotSerializableException::class.java).isThrownBy {
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext)
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
}.withStackTraceContaining("Zero not allowed")
}
@ -1198,7 +1198,7 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
val c = C(Amount(100, BigDecimal("1.5"), Currency.getInstance("USD")))
// were the issue not fixed we'd blow up here
SerializationOutput(factory, compression).serialize(c)
SerializationOutput(factory).serialize(c, compression)
}
@Test
@ -1206,9 +1206,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
compression ?: return
val factory = defaultFactory()
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = SerializationOutput(factory, compression).serialize(data)
val compressed = SerializationOutput(factory).serialize(data, compression)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed))
assertArrayEquals(data, DeserializationInput(factory).deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)))
}
@Test
@ -1216,9 +1216,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
compression ?: return
val factory = defaultFactory()
doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression)
val compressed = SerializationOutput(factory, compression).serialize("whatever")
val input = DeserializationInput(factory, encodingWhitelist)
catchThrowable { input.deserialize(compressed) }.run {
val compressed = SerializationOutput(factory).serialize("whatever", compression)
val input = DeserializationInput(factory)
catchThrowable { input.deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) }.run {
assertSame(NotSerializableException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -1348,5 +1348,16 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
throw Error("Deserializing serialized \$C should not throw")
}
}
// Serialises a 20000-byte array twice through the same SerializationOutput: once with no encoding argument and
// once with SNAPPY, then pins the exact byte count of each result.
@Test
fun `compression reduces number of bytes significantly`() {
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
val obj = ByteArray(20000)
val uncompressedSize = ser.serialize(obj).bytes.size
val compressedSize = ser.serialize(obj, CordaSerializationEncoding.SNAPPY).bytes.size
// Ordinarily this might be considered high maintenance, but we promised wire compatibility, so they'd better not change!
assertEquals(20059, uncompressedSize)
assertEquals(1018, compressedSize)
}
}

View File

@ -4,6 +4,7 @@ import net.corda.core.internal.copyTo
import net.corda.core.internal.div
import net.corda.core.internal.packageName
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.AllWhitelist
@ -98,9 +99,9 @@ fun <T : Any> SerializationOutput.serializeAndReturnSchema(
@Throws(NotSerializableException::class)
fun <T : Any> SerializationOutput.serialize(obj: T): SerializedBytes<T> {
fun <T : Any> SerializationOutput.serialize(obj: T, encoding: SerializationEncoding? = null): SerializedBytes<T> {
try {
return _serialize(obj, testSerializationContext)
return _serialize(obj, testSerializationContext.withEncoding(encoding))
} finally {
andFinally()
}

View File

@ -83,7 +83,7 @@ class BlobInspector : Runnable {
?: throw IllegalArgumentException("Error: this input does not appear to be encoded in Corda's AMQP extended format, sorry.")
if (schema) {
val envelope = DeserializationInput.getEnvelope(bytes.sequence())
val envelope = DeserializationInput.getEnvelope(bytes.sequence(), SerializationDefaults.STORAGE_CONTEXT.encodingWhitelist)
out.println(envelope.schema)
out.println()
out.println(envelope.transformsSchema)