ENT-2439 Fix compression in serialization (#3825)

* ENT-2439 Fix compression in serialization
This commit is contained in:
Rick Parker
2018-08-22 10:37:18 +01:00
committed by GitHub
parent 96d645c316
commit 1d05c16942
10 changed files with 83 additions and 38 deletions

View File

@ -8,9 +8,9 @@ import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.Envelope
import net.corda.serialization.internal.amqp.SerializerFactory
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.amqpSpecific
import net.corda.testing.internal.kryoSpecific
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
@ -28,7 +28,7 @@ class ListsSerializationTest {
fun <T : Any> verifyEnvelope(serBytes: SerializedBytes<T>, envVerBody: (Envelope) -> Unit) =
amqpSpecific("AMQP specific envelope verification") {
val context = SerializationFactory.defaultFactory.defaultContext
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes)
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes, context)
envVerBody(envelope)
}
}

View File

@ -219,8 +219,8 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
freshDeserializationFactory: SerializerFactory = defaultFactory(),
expectedEqual: Boolean = true,
expectDeserializedEqual: Boolean = true): T {
val ser = SerializationOutput(factory, compression)
val bytes = ser.serialize(obj)
val ser = SerializationOutput(factory)
val bytes = ser.serialize(obj, compression)
val decoder = DecoderImpl().apply {
this.register(Envelope.DESCRIPTOR, Envelope.Companion)
@ -241,14 +241,14 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
val result = decoder.readObject() as Envelope
assertNotNull(result)
}
val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist)
val desObj = des.deserialize(bytes)
val des = DeserializationInput(freshDeserializationFactory)
val desObj = des.deserialize(bytes, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
// Now repeat with a re-used factory
val ser2 = SerializationOutput(factory, compression)
val des2 = DeserializationInput(factory, encodingWhitelist)
val desObj2 = des2.deserialize(ser2.serialize(obj))
val ser2 = SerializationOutput(factory)
val des2 = DeserializationInput(factory)
val desObj2 = des2.deserialize(ser2.serialize(obj, compression), testSerializationContext.withEncodingWhitelist(encodingWhitelist))
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
@ -471,10 +471,10 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
@Test
fun `class constructor is invoked on deserialisation`() {
compression == null || return // Manipulation of serialized bytes is invalid if they're compressed.
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression)
val des = DeserializationInput(ser.serializerFactory, encodingWhitelist)
val serialisedOne = ser.serialize(NonZeroByte(1)).bytes
val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
val des = DeserializationInput(ser.serializerFactory)
val serialisedOne = ser.serialize(NonZeroByte(1), compression).bytes
val serialisedTwo = ser.serialize(NonZeroByte(2), compression).bytes
// Find the index that holds the value byte
val valueIndex = serialisedOne.zip(serialisedTwo).mapIndexedNotNull { index, (oneByte, twoByte) ->
@ -485,12 +485,12 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
// Double check
copy[valueIndex] = 0x03
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext).value).isEqualTo(3)
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist)).value).isEqualTo(3)
// Now use the forbidden value
copy[valueIndex] = 0x00
assertThatExceptionOfType(NotSerializableException::class.java).isThrownBy {
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext)
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
}.withStackTraceContaining("Zero not allowed")
}
@ -1198,7 +1198,7 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
val c = C(Amount(100, BigDecimal("1.5"), Currency.getInstance("USD")))
// were the issue not fixed we'd blow up here
SerializationOutput(factory, compression).serialize(c)
SerializationOutput(factory).serialize(c, compression)
}
@Test
@ -1206,9 +1206,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
compression ?: return
val factory = defaultFactory()
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = SerializationOutput(factory, compression).serialize(data)
val compressed = SerializationOutput(factory).serialize(data, compression)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed))
assertArrayEquals(data, DeserializationInput(factory).deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)))
}
@Test
@ -1216,9 +1216,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
compression ?: return
val factory = defaultFactory()
doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression)
val compressed = SerializationOutput(factory, compression).serialize("whatever")
val input = DeserializationInput(factory, encodingWhitelist)
catchThrowable { input.deserialize(compressed) }.run {
val compressed = SerializationOutput(factory).serialize("whatever", compression)
val input = DeserializationInput(factory)
catchThrowable { input.deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) }.run {
assertSame(NotSerializableException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -1348,5 +1348,16 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
throw Error("Deserializing serialized \$C should not throw")
}
}
@Test
fun `compression reduces number of bytes significantly`() {
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
val obj = ByteArray(20000)
val uncompressedSize = ser.serialize(obj).bytes.size
val compressedSize = ser.serialize(obj, CordaSerializationEncoding.SNAPPY).bytes.size
// Ordinarily this might be considered high maintenance, but we promised wire compatibility, so they'd better not change!
assertEquals(20059, uncompressedSize)
assertEquals(1018, compressedSize)
}
}

View File

@ -4,6 +4,7 @@ import net.corda.core.internal.copyTo
import net.corda.core.internal.div
import net.corda.core.internal.packageName
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.AllWhitelist
@ -98,9 +99,9 @@ fun <T : Any> SerializationOutput.serializeAndReturnSchema(
@Throws(NotSerializableException::class)
fun <T : Any> SerializationOutput.serialize(obj: T): SerializedBytes<T> {
fun <T : Any> SerializationOutput.serialize(obj: T, encoding: SerializationEncoding? = null): SerializedBytes<T> {
try {
return _serialize(obj, testSerializationContext)
return _serialize(obj, testSerializationContext.withEncoding(encoding))
} finally {
andFinally()
}