From 1d05c16942a1f5815bc430232d4cc5d83e683ab8 Mon Sep 17 00:00:00 2001 From: Rick Parker Date: Wed, 22 Aug 2018 10:37:18 +0100 Subject: [PATCH] ENT-2439 Fix compression in serialization (#3825) * ENT-2439 Fix compression in serialization --- .../core/serialization/SerializationAPI.kt | 5 ++ .../node/serialization/kryo/KryoTests.kt | 14 +++++- .../internal/SerializationFormat.kt | 21 +++++++- .../internal/SerializationScheme.kt | 1 + .../internal/amqp/DeserializationInput.kt | 13 +++-- .../internal/amqp/SerializationOutput.kt | 7 ++- .../internal/ListsSerializationTest.kt | 4 +- .../internal/amqp/SerializationOutputTests.kt | 49 ++++++++++++------- .../internal/amqp/testutils/AMQPTestUtils.kt | 5 +- .../net/corda/blobinspector/BlobInspector.kt | 2 +- 10 files changed, 83 insertions(+), 38 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index c1332c8954..c5df7f7069 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -198,6 +198,11 @@ interface SerializationContext { */ fun withEncoding(encoding: SerializationEncoding?): SerializationContext + /** + * A shallow copy of this context but with the given encoding whitelist. + */ + fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): SerializationContext + /** * The use case that we are serializing for, since it influences the implementations chosen. */ diff --git a/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt b/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt index 4570ea1495..b15374b7ac 100644 --- a/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt +++ b/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt @@ -44,7 +44,6 @@ class TestScheme : AbstractKryoSerializationScheme() { override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() - } @RunWith(Parameterized::class) @@ -89,7 +88,6 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null)) } - @Test fun `serialised form is stable when the same object instance is added to the deserialised object graph`() { val noReferencesContext = context.withoutReferences() @@ -356,4 +354,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { assertEquals(encodingNotPermittedFormat.format(compression), message) } } + + @Test + fun `compression reduces number of bytes significantly`() { + class Holder(val holder: ByteArray) + + val obj = Holder(ByteArray(20000)) + val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size + val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size + // If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade. + assertEquals(20222, uncompressedSize) + assertEquals(1111, compressedSize) + } } \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt index d275643fc3..7eb236f23d 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt @@ -7,6 +7,7 @@ import net.corda.core.utilities.OpaqueBytes import net.corda.serialization.internal.OrdinalBits.OrdinalWriter import org.iq80.snappy.SnappyFramedInputStream import org.iq80.snappy.SnappyFramedOutputStream +import java.io.IOException import java.io.InputStream import java.io.OutputStream import java.nio.ByteBuffer @@ -44,7 +45,7 @@ enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter { override fun wrap(stream: InputStream) = InflaterInputStream(stream) }, SNAPPY { - override fun wrap(stream: OutputStream) = SnappyFramedOutputStream(stream) + override fun wrap(stream: OutputStream) = FlushAverseOutputStream(SnappyFramedOutputStream(stream)) override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false) }; @@ -58,3 +59,21 @@ enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter { } const val encodingNotPermittedFormat = "Encoding not permitted: %s" + +/** + * Has an empty flush implementation. This is because Kryo keeps calling flush all the time, which stops the Snappy + * stream from building up big chunks to compress and instead keeps compressing small chunks giving terrible compression ratio. + */ +class FlushAverseOutputStream(private val delegate: OutputStream) : OutputStream() { + @Throws(IOException::class) + override fun write(b: Int) = delegate.write(b) + + @Throws(IOException::class) + override fun write(b: ByteArray?, off: Int, len: Int) = delegate.write(b, off, len) + + @Throws(IOException::class) + override fun close() { + delegate.flush() + delegate.close() + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt index c3f9f1e31e..9c78e01926 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt @@ -67,6 +67,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic) override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding) + override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist) } /* diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt index cfa2065c4b..126400f31d 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt @@ -29,9 +29,8 @@ data class ObjectAndEnvelope(val obj: T, val envelope: Envelope) * instances and threads. */ @KeepForDJVM -class DeserializationInput @JvmOverloads constructor( - private val serializerFactory: SerializerFactory, - private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist +class DeserializationInput constructor( + private val serializerFactory: SerializerFactory ) { private val objectHistory: MutableList = mutableListOf() private val logger = loggerFor() @@ -80,9 +79,9 @@ class DeserializationInput @JvmOverloads constructor( } } - + @VisibleForTesting @Throws(AMQPNoTypeNotSerializableException::class) - fun getEnvelope(byteSequence: ByteSequence) = getEnvelope(byteSequence, encodingWhitelist) + fun getEnvelope(byteSequence: ByteSequence, context: SerializationContext) = getEnvelope(byteSequence, context.encodingWhitelist) @Throws( AMQPNotSerializableException::class, @@ -116,7 +115,7 @@ class DeserializationInput @JvmOverloads constructor( @Throws(NotSerializableException::class) fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationContext): T = des { - val envelope = getEnvelope(bytes, encodingWhitelist) + val envelope = getEnvelope(bytes, context.encodingWhitelist) logger.trace("deserialize blob scheme=\"${envelope.schema.toString()}\"") @@ -130,7 +129,7 @@ class DeserializationInput @JvmOverloads constructor( clazz: Class, context: SerializationContext ): ObjectAndEnvelope = des { - val envelope = getEnvelope(bytes, encodingWhitelist) + val envelope = getEnvelope(bytes, context.encodingWhitelist) // Now pick out the obj and schema from the envelope. ObjectAndEnvelope( clazz.cast(readObjectOrNull( diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt index 080fdbc8a1..d24e5ea77b 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt @@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationEncoding import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.contextLogger import net.corda.serialization.internal.CordaSerializationEncoding @@ -28,9 +27,8 @@ data class BytesAndSchemas( * instances and threads. */ @KeepForDJVM -open class SerializationOutput @JvmOverloads constructor( - internal val serializerFactory: SerializerFactory, - private val encoding: SerializationEncoding? = null +open class SerializationOutput constructor( + internal val serializerFactory: SerializerFactory ) { companion object { private val logger = contextLogger() @@ -90,6 +88,7 @@ open class SerializationOutput @JvmOverloads constructor( var stream: OutputStream = it try { amqpMagic.writeTo(stream) + val encoding = context.encoding if (encoding != null) { SectionId.ENCODING.writeTo(stream) (encoding as CordaSerializationEncoding).writeTo(stream) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt index 97ccea49a7..60f71520fe 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt @@ -8,9 +8,9 @@ import net.corda.node.services.statemachine.DataSessionMessage import net.corda.serialization.internal.amqp.DeserializationInput import net.corda.serialization.internal.amqp.Envelope import net.corda.serialization.internal.amqp.SerializerFactory +import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.internal.amqpSpecific import net.corda.testing.internal.kryoSpecific -import net.corda.testing.core.SerializationEnvironmentRule import org.assertj.core.api.Assertions import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals @@ -28,7 +28,7 @@ class ListsSerializationTest { fun verifyEnvelope(serBytes: SerializedBytes, envVerBody: (Envelope) -> Unit) = amqpSpecific("AMQP specific envelope verification") { val context = SerializationFactory.defaultFactory.defaultContext - val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes) + val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes, context) envVerBody(envelope) } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt index c0ffdfc04a..37bf6e8b47 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt @@ -219,8 +219,8 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi freshDeserializationFactory: SerializerFactory = defaultFactory(), expectedEqual: Boolean = true, expectDeserializedEqual: Boolean = true): T { - val ser = SerializationOutput(factory, compression) - val bytes = ser.serialize(obj) + val ser = SerializationOutput(factory) + val bytes = ser.serialize(obj, compression) val decoder = DecoderImpl().apply { this.register(Envelope.DESCRIPTOR, Envelope.Companion) @@ -241,14 +241,14 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi val result = decoder.readObject() as Envelope assertNotNull(result) } - val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist) - val desObj = des.deserialize(bytes) + val des = DeserializationInput(freshDeserializationFactory) + val desObj = des.deserialize(bytes, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual) // Now repeat with a re-used factory - val ser2 = SerializationOutput(factory, compression) - val des2 = DeserializationInput(factory, encodingWhitelist) - val desObj2 = des2.deserialize(ser2.serialize(obj)) + val ser2 = SerializationOutput(factory) + val des2 = DeserializationInput(factory) + val desObj2 = des2.deserialize(ser2.serialize(obj, compression), testSerializationContext.withEncodingWhitelist(encodingWhitelist)) assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual) assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual) @@ -471,10 +471,10 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi @Test fun `class constructor is invoked on deserialisation`() { compression == null || return // Manipulation of serialized bytes is invalid if they're compressed. - val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression) - val des = DeserializationInput(ser.serializerFactory, encodingWhitelist) - val serialisedOne = ser.serialize(NonZeroByte(1)).bytes - val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes + val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())) + val des = DeserializationInput(ser.serializerFactory) + val serialisedOne = ser.serialize(NonZeroByte(1), compression).bytes + val serialisedTwo = ser.serialize(NonZeroByte(2), compression).bytes // Find the index that holds the value byte val valueIndex = serialisedOne.zip(serialisedTwo).mapIndexedNotNull { index, (oneByte, twoByte) -> @@ -485,12 +485,12 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi // Double check copy[valueIndex] = 0x03 - assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext).value).isEqualTo(3) + assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist)).value).isEqualTo(3) // Now use the forbidden value copy[valueIndex] = 0x00 assertThatExceptionOfType(NotSerializableException::class.java).isThrownBy { - des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext) + des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) }.withStackTraceContaining("Zero not allowed") } @@ -1198,7 +1198,7 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi val c = C(Amount(100, BigDecimal("1.5"), Currency.getInstance("USD"))) // were the issue not fixed we'd blow up here - SerializationOutput(factory, compression).serialize(c) + SerializationOutput(factory).serialize(c, compression) } @Test @@ -1206,9 +1206,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi compression ?: return val factory = defaultFactory() val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } - val compressed = SerializationOutput(factory, compression).serialize(data) + val compressed = SerializationOutput(factory).serialize(data, compression) assertEquals(.5, compressed.size.toDouble() / data.size, .03) - assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed)) + assertArrayEquals(data, DeserializationInput(factory).deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist))) } @Test @@ -1216,9 +1216,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi compression ?: return val factory = defaultFactory() doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression) - val compressed = SerializationOutput(factory, compression).serialize("whatever") - val input = DeserializationInput(factory, encodingWhitelist) - catchThrowable { input.deserialize(compressed) }.run { + val compressed = SerializationOutput(factory).serialize("whatever", compression) + val input = DeserializationInput(factory) + catchThrowable { input.deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) }.run { assertSame(NotSerializableException::class.java, javaClass) assertEquals(encodingNotPermittedFormat.format(compression), message) } @@ -1348,5 +1348,16 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi throw Error("Deserializing serialized \$C should not throw") } } + + @Test + fun `compression reduces number of bytes significantly`() { + val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())) + val obj = ByteArray(20000) + val uncompressedSize = ser.serialize(obj).bytes.size + val compressedSize = ser.serialize(obj, CordaSerializationEncoding.SNAPPY).bytes.size + // Ordinarily this might be considered high maintenance, but we promised wire compatibility, so they'd better not change! + assertEquals(20059, uncompressedSize) + assertEquals(1018, compressedSize) + } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt index 9f300fd9e8..e3820e2d6b 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt @@ -4,6 +4,7 @@ import net.corda.core.internal.copyTo import net.corda.core.internal.div import net.corda.core.internal.packageName import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationEncoding import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.OpaqueBytes import net.corda.serialization.internal.AllWhitelist @@ -98,9 +99,9 @@ fun SerializationOutput.serializeAndReturnSchema( @Throws(NotSerializableException::class) -fun SerializationOutput.serialize(obj: T): SerializedBytes { +fun SerializationOutput.serialize(obj: T, encoding: SerializationEncoding? = null): SerializedBytes { try { - return _serialize(obj, testSerializationContext) + return _serialize(obj, testSerializationContext.withEncoding(encoding)) } finally { andFinally() } diff --git a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt index 94e1a45747..655ddb830c 100644 --- a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt +++ b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt @@ -83,7 +83,7 @@ class BlobInspector : Runnable { ?: throw IllegalArgumentException("Error: this input does not appear to be encoded in Corda's AMQP extended format, sorry.") if (schema) { - val envelope = DeserializationInput.getEnvelope(bytes.sequence()) + val envelope = DeserializationInput.getEnvelope(bytes.sequence(), SerializationDefaults.STORAGE_CONTEXT.encodingWhitelist) out.println(envelope.schema) out.println() out.println(envelope.transformsSchema)