mirror of
https://github.com/corda/corda.git
synced 2024-12-18 20:47:57 +00:00
ENT-2439 Fix compression in serialization (#3825)
* ENT-2439 Fix compression in serialization
This commit is contained in:
parent
96d645c316
commit
1d05c16942
@ -198,6 +198,11 @@ interface SerializationContext {
|
||||
*/
|
||||
fun withEncoding(encoding: SerializationEncoding?): SerializationContext
|
||||
|
||||
/**
|
||||
* A shallow copy of this context but with the given encoding whitelist.
|
||||
*/
|
||||
fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): SerializationContext
|
||||
|
||||
/**
|
||||
* The use case that we are serializing for, since it influences the implementations chosen.
|
||||
*/
|
||||
|
@ -44,7 +44,6 @@ class TestScheme : AbstractKryoSerializationScheme() {
|
||||
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
|
||||
|
||||
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
|
||||
|
||||
}
|
||||
|
||||
@RunWith(Parameterized::class)
|
||||
@ -89,7 +88,6 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
|
||||
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null))
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `serialised form is stable when the same object instance is added to the deserialised object graph`() {
|
||||
val noReferencesContext = context.withoutReferences()
|
||||
@ -356,4 +354,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
|
||||
assertEquals(encodingNotPermittedFormat.format(compression), message)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compression reduces number of bytes significantly`() {
|
||||
class Holder(val holder: ByteArray)
|
||||
|
||||
val obj = Holder(ByteArray(20000))
|
||||
val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size
|
||||
val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
|
||||
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not survive an upgrade.
|
||||
assertEquals(20222, uncompressedSize)
|
||||
assertEquals(1111, compressedSize)
|
||||
}
|
||||
}
|
@ -7,6 +7,7 @@ import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.serialization.internal.OrdinalBits.OrdinalWriter
|
||||
import org.iq80.snappy.SnappyFramedInputStream
|
||||
import org.iq80.snappy.SnappyFramedOutputStream
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
import java.io.OutputStream
|
||||
import java.nio.ByteBuffer
|
||||
@ -44,7 +45,7 @@ enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
|
||||
override fun wrap(stream: InputStream) = InflaterInputStream(stream)
|
||||
},
|
||||
SNAPPY {
|
||||
override fun wrap(stream: OutputStream) = SnappyFramedOutputStream(stream)
|
||||
override fun wrap(stream: OutputStream) = FlushAverseOutputStream(SnappyFramedOutputStream(stream))
|
||||
override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false)
|
||||
};
|
||||
|
||||
@ -58,3 +59,21 @@ enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
|
||||
}
|
||||
|
||||
const val encodingNotPermittedFormat = "Encoding not permitted: %s"
|
||||
|
||||
/**
 * An [OutputStream] wrapper that forwards writes to [delegate] but ignores [flush] requests.
 *
 * Kryo calls flush very frequently, which stops the Snappy stream from building up big chunks to compress;
 * it ends up compressing many small chunks instead, giving a terrible compression ratio. Swallowing the
 * intermediate flushes lets the wrapped stream decide its own chunk boundaries. The delegate is flushed
 * exactly once, when this stream is closed.
 *
 * @param delegate the stream that receives all written bytes and is closed together with this wrapper.
 */
class FlushAverseOutputStream(private val delegate: OutputStream) : OutputStream() {
    @Throws(IOException::class)
    override fun write(b: Int) = delegate.write(b)

    @Throws(IOException::class)
    override fun write(b: ByteArray?, off: Int, len: Int) = delegate.write(b, off, len)

    /** Intentionally does nothing: intermediate flushes are not propagated to the delegate. */
    override fun flush() {
        // No-op by design; the single real flush happens in close().
    }

    /**
     * Flushes the delegate once and then closes it.
     *
     * The close happens in a `finally` block so the delegate is still released when the final flush fails.
     *
     * @throws IOException if flushing or closing the delegate fails.
     */
    @Throws(IOException::class)
    override fun close() {
        try {
            delegate.flush()
        } finally {
            delegate.close()
        }
    }
}
|
||||
|
@ -67,6 +67,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
|
||||
|
||||
override fun withPreferredSerializationVersion(magic: SerializationMagic) = copy(preferredSerializationVersion = magic)
|
||||
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
|
||||
override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist)
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -29,9 +29,8 @@ data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
|
||||
* instances and threads.
|
||||
*/
|
||||
@KeepForDJVM
|
||||
class DeserializationInput @JvmOverloads constructor(
|
||||
private val serializerFactory: SerializerFactory,
|
||||
private val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist
|
||||
class DeserializationInput constructor(
|
||||
private val serializerFactory: SerializerFactory
|
||||
) {
|
||||
private val objectHistory: MutableList<Any> = mutableListOf()
|
||||
private val logger = loggerFor<DeserializationInput>()
|
||||
@ -80,9 +79,9 @@ class DeserializationInput @JvmOverloads constructor(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@VisibleForTesting
|
||||
@Throws(AMQPNoTypeNotSerializableException::class)
|
||||
fun getEnvelope(byteSequence: ByteSequence) = getEnvelope(byteSequence, encodingWhitelist)
|
||||
fun getEnvelope(byteSequence: ByteSequence, context: SerializationContext) = getEnvelope(byteSequence, context.encodingWhitelist)
|
||||
|
||||
@Throws(
|
||||
AMQPNotSerializableException::class,
|
||||
@ -116,7 +115,7 @@ class DeserializationInput @JvmOverloads constructor(
|
||||
@Throws(NotSerializableException::class)
|
||||
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationContext): T =
|
||||
des {
|
||||
val envelope = getEnvelope(bytes, encodingWhitelist)
|
||||
val envelope = getEnvelope(bytes, context.encodingWhitelist)
|
||||
|
||||
logger.trace("deserialize blob scheme=\"${envelope.schema.toString()}\"")
|
||||
|
||||
@ -130,7 +129,7 @@ class DeserializationInput @JvmOverloads constructor(
|
||||
clazz: Class<T>,
|
||||
context: SerializationContext
|
||||
): ObjectAndEnvelope<T> = des {
|
||||
val envelope = getEnvelope(bytes, encodingWhitelist)
|
||||
val envelope = getEnvelope(bytes, context.encodingWhitelist)
|
||||
// Now pick out the obj and schema from the envelope.
|
||||
ObjectAndEnvelope(
|
||||
clazz.cast(readObjectOrNull(
|
||||
|
@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp
|
||||
|
||||
import net.corda.core.KeepForDJVM
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationEncoding
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.serialization.internal.CordaSerializationEncoding
|
||||
@ -28,9 +27,8 @@ data class BytesAndSchemas<T : Any>(
|
||||
* instances and threads.
|
||||
*/
|
||||
@KeepForDJVM
|
||||
open class SerializationOutput @JvmOverloads constructor(
|
||||
internal val serializerFactory: SerializerFactory,
|
||||
private val encoding: SerializationEncoding? = null
|
||||
open class SerializationOutput constructor(
|
||||
internal val serializerFactory: SerializerFactory
|
||||
) {
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
@ -90,6 +88,7 @@ open class SerializationOutput @JvmOverloads constructor(
|
||||
var stream: OutputStream = it
|
||||
try {
|
||||
amqpMagic.writeTo(stream)
|
||||
val encoding = context.encoding
|
||||
if (encoding != null) {
|
||||
SectionId.ENCODING.writeTo(stream)
|
||||
(encoding as CordaSerializationEncoding).writeTo(stream)
|
||||
|
@ -8,9 +8,9 @@ import net.corda.node.services.statemachine.DataSessionMessage
|
||||
import net.corda.serialization.internal.amqp.DeserializationInput
|
||||
import net.corda.serialization.internal.amqp.Envelope
|
||||
import net.corda.serialization.internal.amqp.SerializerFactory
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.internal.amqpSpecific
|
||||
import net.corda.testing.internal.kryoSpecific
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
@ -28,7 +28,7 @@ class ListsSerializationTest {
|
||||
fun <T : Any> verifyEnvelope(serBytes: SerializedBytes<T>, envVerBody: (Envelope) -> Unit) =
|
||||
amqpSpecific("AMQP specific envelope verification") {
|
||||
val context = SerializationFactory.defaultFactory.defaultContext
|
||||
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes)
|
||||
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes, context)
|
||||
envVerBody(envelope)
|
||||
}
|
||||
}
|
||||
|
@ -219,8 +219,8 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
freshDeserializationFactory: SerializerFactory = defaultFactory(),
|
||||
expectedEqual: Boolean = true,
|
||||
expectDeserializedEqual: Boolean = true): T {
|
||||
val ser = SerializationOutput(factory, compression)
|
||||
val bytes = ser.serialize(obj)
|
||||
val ser = SerializationOutput(factory)
|
||||
val bytes = ser.serialize(obj, compression)
|
||||
|
||||
val decoder = DecoderImpl().apply {
|
||||
this.register(Envelope.DESCRIPTOR, Envelope.Companion)
|
||||
@ -241,14 +241,14 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
val result = decoder.readObject() as Envelope
|
||||
assertNotNull(result)
|
||||
}
|
||||
val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist)
|
||||
val desObj = des.deserialize(bytes)
|
||||
val des = DeserializationInput(freshDeserializationFactory)
|
||||
val desObj = des.deserialize(bytes, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
|
||||
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
|
||||
|
||||
// Now repeat with a re-used factory
|
||||
val ser2 = SerializationOutput(factory, compression)
|
||||
val des2 = DeserializationInput(factory, encodingWhitelist)
|
||||
val desObj2 = des2.deserialize(ser2.serialize(obj))
|
||||
val ser2 = SerializationOutput(factory)
|
||||
val des2 = DeserializationInput(factory)
|
||||
val desObj2 = des2.deserialize(ser2.serialize(obj, compression), testSerializationContext.withEncodingWhitelist(encodingWhitelist))
|
||||
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
|
||||
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
|
||||
|
||||
@ -471,10 +471,10 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
@Test
|
||||
fun `class constructor is invoked on deserialisation`() {
|
||||
compression == null || return // Manipulation of serialized bytes is invalid if they're compressed.
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression)
|
||||
val des = DeserializationInput(ser.serializerFactory, encodingWhitelist)
|
||||
val serialisedOne = ser.serialize(NonZeroByte(1)).bytes
|
||||
val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
|
||||
val des = DeserializationInput(ser.serializerFactory)
|
||||
val serialisedOne = ser.serialize(NonZeroByte(1), compression).bytes
|
||||
val serialisedTwo = ser.serialize(NonZeroByte(2), compression).bytes
|
||||
|
||||
// Find the index that holds the value byte
|
||||
val valueIndex = serialisedOne.zip(serialisedTwo).mapIndexedNotNull { index, (oneByte, twoByte) ->
|
||||
@ -485,12 +485,12 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
|
||||
// Double check
|
||||
copy[valueIndex] = 0x03
|
||||
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext).value).isEqualTo(3)
|
||||
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist)).value).isEqualTo(3)
|
||||
|
||||
// Now use the forbidden value
|
||||
copy[valueIndex] = 0x00
|
||||
assertThatExceptionOfType(NotSerializableException::class.java).isThrownBy {
|
||||
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext)
|
||||
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
|
||||
}.withStackTraceContaining("Zero not allowed")
|
||||
}
|
||||
|
||||
@ -1198,7 +1198,7 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
val c = C(Amount(100, BigDecimal("1.5"), Currency.getInstance("USD")))
|
||||
|
||||
// were the issue not fixed we'd blow up here
|
||||
SerializationOutput(factory, compression).serialize(c)
|
||||
SerializationOutput(factory).serialize(c, compression)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -1206,9 +1206,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
compression ?: return
|
||||
val factory = defaultFactory()
|
||||
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
|
||||
val compressed = SerializationOutput(factory, compression).serialize(data)
|
||||
val compressed = SerializationOutput(factory).serialize(data, compression)
|
||||
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
|
||||
assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed))
|
||||
assertArrayEquals(data, DeserializationInput(factory).deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -1216,9 +1216,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
compression ?: return
|
||||
val factory = defaultFactory()
|
||||
doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression)
|
||||
val compressed = SerializationOutput(factory, compression).serialize("whatever")
|
||||
val input = DeserializationInput(factory, encodingWhitelist)
|
||||
catchThrowable { input.deserialize(compressed) }.run {
|
||||
val compressed = SerializationOutput(factory).serialize("whatever", compression)
|
||||
val input = DeserializationInput(factory)
|
||||
catchThrowable { input.deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) }.run {
|
||||
assertSame(NotSerializableException::class.java, javaClass)
|
||||
assertEquals(encodingNotPermittedFormat.format(compression), message)
|
||||
}
|
||||
@ -1348,5 +1348,16 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
throw Error("Deserializing serialized \$C should not throw")
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compression reduces number of bytes significantly`() {
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
|
||||
val obj = ByteArray(20000)
|
||||
val uncompressedSize = ser.serialize(obj).bytes.size
|
||||
val compressedSize = ser.serialize(obj, CordaSerializationEncoding.SNAPPY).bytes.size
|
||||
// Ordinarily this might be considered high maintenance, but we promised wire compatibility, so they'd better not change!
|
||||
assertEquals(20059, uncompressedSize)
|
||||
assertEquals(1018, compressedSize)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@ import net.corda.core.internal.copyTo
|
||||
import net.corda.core.internal.div
|
||||
import net.corda.core.internal.packageName
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationEncoding
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.serialization.internal.AllWhitelist
|
||||
@ -98,9 +99,9 @@ fun <T : Any> SerializationOutput.serializeAndReturnSchema(
|
||||
|
||||
|
||||
@Throws(NotSerializableException::class)
|
||||
fun <T : Any> SerializationOutput.serialize(obj: T): SerializedBytes<T> {
|
||||
fun <T : Any> SerializationOutput.serialize(obj: T, encoding: SerializationEncoding? = null): SerializedBytes<T> {
|
||||
try {
|
||||
return _serialize(obj, testSerializationContext)
|
||||
return _serialize(obj, testSerializationContext.withEncoding(encoding))
|
||||
} finally {
|
||||
andFinally()
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ class BlobInspector : Runnable {
|
||||
?: throw IllegalArgumentException("Error: this input does not appear to be encoded in Corda's AMQP extended format, sorry.")
|
||||
|
||||
if (schema) {
|
||||
val envelope = DeserializationInput.getEnvelope(bytes.sequence())
|
||||
val envelope = DeserializationInput.getEnvelope(bytes.sequence(), SerializationDefaults.STORAGE_CONTEXT.encodingWhitelist)
|
||||
out.println(envelope.schema)
|
||||
out.println()
|
||||
out.println(envelope.transformsSchema)
|
||||
|
Loading…
Reference in New Issue
Block a user