Implement zstandard compression. Parameterise the test.

This commit is contained in:
rick.parker 2022-11-04 16:13:36 +00:00
parent 84bf3d5639
commit 137734991b
7 changed files with 112 additions and 11 deletions

View File

@ -1,5 +1,6 @@
package net.corda.finance.flows package net.corda.finance.flows
import com.github.luben.zstd.ZstdDictTrainer
import net.corda.core.serialization.ExternalSchema import net.corda.core.serialization.ExternalSchema
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.SerializationFactory
@ -22,14 +23,57 @@ import net.corda.serialization.internal.amqp.custom.PublicKeySerializer
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNotNull import kotlin.test.assertNotNull
import kotlin.test.assertTrue import kotlin.test.assertTrue
import kotlin.test.fail import kotlin.test.fail
/**
 * Schema-handling mode that [CompatibilityTest] is parameterised over.
 *
 * The effect of each mode is determined by how the test builds its serialization contexts; the
 * descriptions below only restate those branches.
 */
enum class SchemaType {
    /** The context is used as-is: the test attaches no external schema. */
    REPEAT,

    /** An external schema is attached, and the final serialize call copies it with `flush = true`. */
    SINGLE,

    /** An external schema is attached, and the final serialize call copies it with `flush = false`. */
    NONE
}
// TODO: If this type of testing gets momentum, we can create a mini-framework that rides through list of files // TODO: If this type of testing gets momentum, we can create a mini-framework that rides through list of files
// and performs necessary validation on all of them. // and performs necessary validation on all of them.
class CompatibilityTest { @RunWith(Parameterized::class)
class CompatibilityTest(val encoding: CordaSerializationEncoding?, val schemaType: SchemaType, val useDictionary: Boolean, val useIntegerFingerprints: Boolean) {
companion object {
    /**
     * Supplies the constructor arguments for each parameterised run, in the order
     * `(encoding, schemaType, useDictionary, useIntegerFingerprints)`.
     *
     * No encoding (`null`), DEFLATE and SNAPPY are each run without a dictionary, first without and then
     * with integer fingerprints. ZSTANDARD is run for three (dictionary, fingerprints) pairs. Every such
     * combination is expanded over all [SchemaType] values in declaration order.
     */
    @Parameterized.Parameters(name = "encoding: {0}, schemaType: {1}, useDictionary: {2}, useIntegerFingerprints: {3}")
    @JvmStatic
    fun data(): List<Array<Any?>> {
        // Triple = (encoding, useDictionary, useIntegerFingerprints).
        val withoutDictionary = listOf(null, CordaSerializationEncoding.DEFLATE, CordaSerializationEncoding.SNAPPY)
            .flatMap { codec -> listOf(false, true).map { fingerprints -> Triple(codec, false, fingerprints) } }

        // NOTE(review): ZSTANDARD without a dictionary but with integer fingerprints is not in the
        // matrix - confirm whether that omission is intentional.
        val zstandard = listOf(false to false, true to false, true to true)
            .map { (dictionary, fingerprints) -> Triple(CordaSerializationEncoding.ZSTANDARD, dictionary, fingerprints) }

        return (withoutDictionary + zstandard).flatMap { (codec, dictionary, fingerprints) ->
            SchemaType.values().map { schema -> arrayOf<Any?>(codec, schema, dictionary, fingerprints) }
        }
    }
}
@Rule @Rule
@JvmField @JvmField
@ -47,7 +91,7 @@ class CompatibilityTest {
assertNotNull(inputStream) assertNotNull(inputStream)
val inByteArray: ByteArray = inputStream.readBytes() val inByteArray: ByteArray = inputStream.readBytes()
println("Original size = ${inByteArray.size}") //println("Original size = ${inByteArray.size}")
val input = DeserializationInput(serializerFactory) val input = DeserializationInput(serializerFactory)
val (transaction, envelope) = input.deserializeAndReturnEnvelope( val (transaction, envelope) = input.deserializeAndReturnEnvelope(
@ -60,7 +104,27 @@ class CompatibilityTest {
assertEquals(1, commands.size) assertEquals(1, commands.size)
assertTrue(commands.first().value is Cash.Commands.Issue) assertTrue(commands.first().value is Cash.Commands.Issue)
val context = SerializationDefaults.STORAGE_CONTEXT.withExternalSchema(ExternalSchema()).withIntegerFingerprint() val networkParams = javaClass.classLoader.getResourceAsStream("networkParams.r3corda.6a6b6f256").readBytes()
val trainer = ZstdDictTrainer(128 * 1024, 128 * 1024)
while (useDictionary) {
if (!trainer.addSample(inByteArray)) break
if (!trainer.addSample(networkParams)) break
}
val dict = if (useDictionary) trainer.trainSamples() else ByteArray(0)
val context = SerializationDefaults.STORAGE_CONTEXT.let {
if (schemaType != SchemaType.REPEAT) {
it.withExternalSchema(ExternalSchema())
} else {
it
}
}.let {
if (useIntegerFingerprints) {
it.withIntegerFingerprint()
} else {
it
}
}
//.withExternalSchema(ExternalSchema()).withIntegerFingerprint()
val newWtx = SerializationFactory.defaultFactory.asCurrent { val newWtx = SerializationFactory.defaultFactory.asCurrent {
withCurrentContext(context) { withCurrentContext(context) {
WireTransaction(transaction.tx.componentGroups.map { cg: ComponentGroup -> WireTransaction(transaction.tx.componentGroups.map { cg: ComponentGroup ->
@ -78,10 +142,27 @@ class CompatibilityTest {
// Serialize back and check that representation is byte-to-byte identical to what it was originally. // Serialize back and check that representation is byte-to-byte identical to what it was originally.
val output = SerializationOutput(serializerFactory) val output = SerializationOutput(serializerFactory)
val outByteArray = output.serialize(newTransaction, context.withExternalSchema(context.externalSchema!!.copy(flush = true)) val outerContext = context.let {
.withEncoding(CordaSerializationEncoding.SNAPPY)).bytes if (encoding != null) {
it.withEncoding(encoding)
} else it
}.let {
if (schemaType != SchemaType.REPEAT) {
it.withExternalSchema(context.externalSchema!!.copy(flush = schemaType == SchemaType.SINGLE))
} else {
it
}
}.let {
if (useDictionary) {
it.withProperty(CordaSerializationEncoding.DICTIONARY_KEY, dict)
} else {
it
}
}
val outByteArray = output.serialize(newTransaction, outerContext /*context.withExternalSchema(context.externalSchema!!.copy(flush = true))
.withEncoding(CordaSerializationEncoding.ZSTANDARD).withProperty(CordaSerializationEncoding.DICTIONARY_KEY, dict)*/).bytes
//val (serializedBytes, schema) = output.serializeAndReturnSchema(transaction, SerializationDefaults.STORAGE_CONTEXT) //val (serializedBytes, schema) = output.serializeAndReturnSchema(transaction, SerializationDefaults.STORAGE_CONTEXT)
println("Output size = ${outByteArray.size}") println("encoding: $encoding, schemaType: $schemaType, useDictionary: $useDictionary, useIntegerFingerprints: $useIntegerFingerprints, Output size = ${outByteArray.size}")
//assertSchemasMatch(envelope.schema, schema) //assertSchemasMatch(envelope.schema, schema)

View File

@ -159,7 +159,7 @@ object KryoCheckpointSerializer : CheckpointSerializer {
context.encoding?.let { encoding -> context.encoding?.let { encoding ->
SectionId.ENCODING.writeTo(this) SectionId.ENCODING.writeTo(this)
(encoding as CordaSerializationEncoding).writeTo(this) (encoding as CordaSerializationEncoding).writeTo(this)
substitute(encoding::wrap) substitute { outputStream -> encoding.wrap(outputStream, context.properties) }
} }
SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case. SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case.
if (context.objectReferencesEnabled) { if (context.objectReferencesEnabled) {

View File

@ -1,6 +1,7 @@
import net.corda.gradle.jarfilter.JarFilterTask import net.corda.gradle.jarfilter.JarFilterTask
import net.corda.gradle.jarfilter.MetaFixerTask import net.corda.gradle.jarfilter.MetaFixerTask
import proguard.gradle.ProGuardTask import proguard.gradle.ProGuardTask
import static org.gradle.api.JavaVersion.VERSION_1_8 import static org.gradle.api.JavaVersion.VERSION_1_8
plugins { plugins {
@ -43,6 +44,7 @@ dependencies {
// These "implementation" dependencies will become "runtime" scoped in our published POM. // These "implementation" dependencies will become "runtime" scoped in our published POM.
implementation "org.iq80.snappy:snappy:$snappy_version" implementation "org.iq80.snappy:snappy:$snappy_version"
implementation "com.github.luben:zstd-jni:1.5.2-5"
implementation "com.google.guava:guava:$guava_version" implementation "com.google.guava:guava:$guava_version"
} }

View File

@ -30,6 +30,9 @@ dependencies {
// Pure-Java Snappy compression // Pure-Java Snappy compression
compile "org.iq80.snappy:snappy:$snappy_version" compile "org.iq80.snappy:snappy:$snappy_version"
// JNI based Zstandard compression
compile "com.github.luben:zstd-jni:1.5.2-5"
// For caches rather than guava // For caches rather than guava
compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version" compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version"

View File

@ -1,5 +1,7 @@
package net.corda.serialization.internal package net.corda.serialization.internal
import com.github.luben.zstd.ZstdInputStream
import com.github.luben.zstd.ZstdOutputStream
import net.corda.core.KeepForDJVM import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializationEncoding import net.corda.core.serialization.SerializationEncoding
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
@ -41,20 +43,33 @@ enum class SectionId : OrdinalWriter {
@KeepForDJVM @KeepForDJVM
enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter { enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
DEFLATE { DEFLATE {
override fun wrap(stream: OutputStream) = DeflaterOutputStream(stream) override fun wrap(stream: OutputStream, context: Map<Any, Any>) = DeflaterOutputStream(stream)
override fun wrap(stream: InputStream) = InflaterInputStream(stream) override fun wrap(stream: InputStream) = InflaterInputStream(stream)
}, },
SNAPPY { SNAPPY {
override fun wrap(stream: OutputStream) = FlushAverseOutputStream(SnappyFramedOutputStream(stream)) override fun wrap(stream: OutputStream, context: Map<Any, Any>) = FlushAverseOutputStream(SnappyFramedOutputStream(stream))
override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false) override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false)
},
ZSTANDARD {
/**
 * Wraps [stream] in a Zstandard compressing stream, applying optional settings from [context].
 *
 * If [context] holds a value under [LEVEL_KEY] it is passed to `setLevel`; if it holds a value under
 * [DICTIONARY_KEY] it is passed to `setDict`. A missing key leaves that setting untouched. A value of an
 * unexpected type fails the cast below rather than being ignored.
 *
 * NOTE(review): the [InputStream] overload of this entry receives no context and configures no
 * dictionary - confirm how data compressed with a dictionary is expected to be read back.
 */
override fun wrap(stream: OutputStream, context: Map<Any, Any>): ZstdOutputStream {
    val compressor = ZstdOutputStream(stream)
    // The level is read and applied before the dictionary is looked up.
    (context[LEVEL_KEY] as Int?)?.let { level -> compressor.setLevel(level) }
    (context[DICTIONARY_KEY] as ByteArray?)?.let { dictionary -> compressor.setDict(dictionary) }
    return compressor
}
override fun wrap(stream: InputStream) = ZstdInputStream(stream)
}; };
companion object { companion object {
val reader = OrdinalReader(values()) val reader = OrdinalReader(values())
val LEVEL_KEY: String = "ZSTD_LEVEL_KEY"
val DICTIONARY_KEY: String = "ZSTD_DICTIONARY"
} }
override val bits = OrdinalBits(ordinal) override val bits = OrdinalBits(ordinal)
abstract fun wrap(stream: OutputStream): OutputStream abstract fun wrap(stream: OutputStream, context: Map<Any, Any>): OutputStream
abstract fun wrap(stream: InputStream): InputStream abstract fun wrap(stream: InputStream): InputStream
} }

View File

@ -106,7 +106,7 @@ open class SerializationOutput constructor(
if (encoding != null) { if (encoding != null) {
SectionId.ENCODING.writeTo(stream) SectionId.ENCODING.writeTo(stream)
(encoding as CordaSerializationEncoding).writeTo(stream) (encoding as CordaSerializationEncoding).writeTo(stream)
stream = encoding.wrap(stream) stream = encoding.wrap(stream, context.properties)
} }
SectionId.DATA_AND_STOP.writeTo(stream) SectionId.DATA_AND_STOP.writeTo(stream)
stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode) stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode)