mirror of
https://github.com/corda/corda.git
synced 2025-06-23 01:19:00 +00:00
ENT-2439 Fix compression in serialization (#3825)
* ENT-2439 Fix compression in serialization
This commit is contained in:
@ -8,9 +8,9 @@ import net.corda.node.services.statemachine.DataSessionMessage
|
||||
import net.corda.serialization.internal.amqp.DeserializationInput
|
||||
import net.corda.serialization.internal.amqp.Envelope
|
||||
import net.corda.serialization.internal.amqp.SerializerFactory
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.internal.amqpSpecific
|
||||
import net.corda.testing.internal.kryoSpecific
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
@ -28,7 +28,7 @@ class ListsSerializationTest {
|
||||
fun <T : Any> verifyEnvelope(serBytes: SerializedBytes<T>, envVerBody: (Envelope) -> Unit) =
|
||||
amqpSpecific("AMQP specific envelope verification") {
|
||||
val context = SerializationFactory.defaultFactory.defaultContext
|
||||
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes)
|
||||
val envelope = DeserializationInput(SerializerFactory(context.whitelist, context.deserializationClassLoader)).getEnvelope(serBytes, context)
|
||||
envVerBody(envelope)
|
||||
}
|
||||
}
|
||||
|
@ -219,8 +219,8 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
freshDeserializationFactory: SerializerFactory = defaultFactory(),
|
||||
expectedEqual: Boolean = true,
|
||||
expectDeserializedEqual: Boolean = true): T {
|
||||
val ser = SerializationOutput(factory, compression)
|
||||
val bytes = ser.serialize(obj)
|
||||
val ser = SerializationOutput(factory)
|
||||
val bytes = ser.serialize(obj, compression)
|
||||
|
||||
val decoder = DecoderImpl().apply {
|
||||
this.register(Envelope.DESCRIPTOR, Envelope.Companion)
|
||||
@ -241,14 +241,14 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
val result = decoder.readObject() as Envelope
|
||||
assertNotNull(result)
|
||||
}
|
||||
val des = DeserializationInput(freshDeserializationFactory, encodingWhitelist)
|
||||
val desObj = des.deserialize(bytes)
|
||||
val des = DeserializationInput(freshDeserializationFactory)
|
||||
val desObj = des.deserialize(bytes, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
|
||||
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
|
||||
|
||||
// Now repeat with a re-used factory
|
||||
val ser2 = SerializationOutput(factory, compression)
|
||||
val des2 = DeserializationInput(factory, encodingWhitelist)
|
||||
val desObj2 = des2.deserialize(ser2.serialize(obj))
|
||||
val ser2 = SerializationOutput(factory)
|
||||
val des2 = DeserializationInput(factory)
|
||||
val desObj2 = des2.deserialize(ser2.serialize(obj, compression), testSerializationContext.withEncodingWhitelist(encodingWhitelist))
|
||||
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
|
||||
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
|
||||
|
||||
@ -471,10 +471,10 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
@Test
|
||||
fun `class constructor is invoked on deserialisation`() {
|
||||
compression == null || return // Manipulation of serialized bytes is invalid if they're compressed.
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), compression)
|
||||
val des = DeserializationInput(ser.serializerFactory, encodingWhitelist)
|
||||
val serialisedOne = ser.serialize(NonZeroByte(1)).bytes
|
||||
val serialisedTwo = ser.serialize(NonZeroByte(2)).bytes
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
|
||||
val des = DeserializationInput(ser.serializerFactory)
|
||||
val serialisedOne = ser.serialize(NonZeroByte(1), compression).bytes
|
||||
val serialisedTwo = ser.serialize(NonZeroByte(2), compression).bytes
|
||||
|
||||
// Find the index that holds the value byte
|
||||
val valueIndex = serialisedOne.zip(serialisedTwo).mapIndexedNotNull { index, (oneByte, twoByte) ->
|
||||
@ -485,12 +485,12 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
|
||||
// Double check
|
||||
copy[valueIndex] = 0x03
|
||||
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext).value).isEqualTo(3)
|
||||
assertThat(des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist)).value).isEqualTo(3)
|
||||
|
||||
// Now use the forbidden value
|
||||
copy[valueIndex] = 0x00
|
||||
assertThatExceptionOfType(NotSerializableException::class.java).isThrownBy {
|
||||
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext)
|
||||
des.deserialize(OpaqueBytes(copy), NonZeroByte::class.java, testSerializationContext.withEncodingWhitelist(encodingWhitelist))
|
||||
}.withStackTraceContaining("Zero not allowed")
|
||||
}
|
||||
|
||||
@ -1198,7 +1198,7 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
val c = C(Amount(100, BigDecimal("1.5"), Currency.getInstance("USD")))
|
||||
|
||||
// were the issue not fixed we'd blow up here
|
||||
SerializationOutput(factory, compression).serialize(c)
|
||||
SerializationOutput(factory).serialize(c, compression)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -1206,9 +1206,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
compression ?: return
|
||||
val factory = defaultFactory()
|
||||
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
|
||||
val compressed = SerializationOutput(factory, compression).serialize(data)
|
||||
val compressed = SerializationOutput(factory).serialize(data, compression)
|
||||
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
|
||||
assertArrayEquals(data, DeserializationInput(factory, encodingWhitelist).deserialize(compressed))
|
||||
assertArrayEquals(data, DeserializationInput(factory).deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)))
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -1216,9 +1216,9 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
compression ?: return
|
||||
val factory = defaultFactory()
|
||||
doReturn(false).whenever(encodingWhitelist).acceptEncoding(compression)
|
||||
val compressed = SerializationOutput(factory, compression).serialize("whatever")
|
||||
val input = DeserializationInput(factory, encodingWhitelist)
|
||||
catchThrowable { input.deserialize(compressed) }.run {
|
||||
val compressed = SerializationOutput(factory).serialize("whatever", compression)
|
||||
val input = DeserializationInput(factory)
|
||||
catchThrowable { input.deserialize(compressed, testSerializationContext.withEncodingWhitelist(encodingWhitelist)) }.run {
|
||||
assertSame(NotSerializableException::class.java, javaClass)
|
||||
assertEquals(encodingNotPermittedFormat.format(compression), message)
|
||||
}
|
||||
@ -1348,5 +1348,16 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi
|
||||
throw Error("Deserializing serialized \$C should not throw")
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `compression reduces number of bytes significantly`() {
|
||||
val ser = SerializationOutput(SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()))
|
||||
val obj = ByteArray(20000)
|
||||
val uncompressedSize = ser.serialize(obj).bytes.size
|
||||
val compressedSize = ser.serialize(obj, CordaSerializationEncoding.SNAPPY).bytes.size
|
||||
// Ordinarily this might be considered high maintenance, but we promised wire compatibility, so they'd better not change!
|
||||
assertEquals(20059, uncompressedSize)
|
||||
assertEquals(1018, compressedSize)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -4,6 +4,7 @@ import net.corda.core.internal.copyTo
|
||||
import net.corda.core.internal.div
|
||||
import net.corda.core.internal.packageName
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationEncoding
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.serialization.internal.AllWhitelist
|
||||
@ -98,9 +99,9 @@ fun <T : Any> SerializationOutput.serializeAndReturnSchema(
|
||||
|
||||
|
||||
@Throws(NotSerializableException::class)
|
||||
fun <T : Any> SerializationOutput.serialize(obj: T): SerializedBytes<T> {
|
||||
fun <T : Any> SerializationOutput.serialize(obj: T, encoding: SerializationEncoding? = null): SerializedBytes<T> {
|
||||
try {
|
||||
return _serialize(obj, testSerializationContext)
|
||||
return _serialize(obj, testSerializationContext.withEncoding(encoding))
|
||||
} finally {
|
||||
andFinally()
|
||||
}
|
||||
|
Reference in New Issue
Block a user