Implement zstandard compression. Parameterise the test.

This commit is contained in:
rick.parker 2022-11-04 16:13:36 +00:00
parent 84bf3d5639
commit 137734991b
7 changed files with 112 additions and 11 deletions

View File

@ -1,5 +1,6 @@
package net.corda.finance.flows
import com.github.luben.zstd.ZstdDictTrainer
import net.corda.core.serialization.ExternalSchema
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory
@ -22,14 +23,57 @@ import net.corda.serialization.internal.amqp.custom.PublicKeySerializer
import net.corda.testing.core.SerializationEnvironmentRule
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import kotlin.test.fail
/**
 * Selects how the serialization schema is handled for one parameterised run of [CompatibilityTest].
 */
enum class SchemaType {
    /** No external schema is attached to the serialization context. */
    REPEAT,

    /** An external schema is attached, and the final serialization uses a copy of it with `flush = true`. */
    SINGLE,

    /** An external schema is attached, and the final serialization uses a copy of it with `flush = false`. */
    NONE
}
// TODO: If this type of testing gains momentum, we can create a mini-framework that iterates over a list of files
// and performs the necessary validation on each of them.
class CompatibilityTest {
@RunWith(Parameterized::class)
class CompatibilityTest(val encoding: CordaSerializationEncoding?, val schemaType: SchemaType, val useDictionary: Boolean, val useIntegerFingerprints: Boolean) {
companion object {
    /**
     * Builds the parameter rows for the [Parameterized] runner.
     *
     * Each row is `(encoding, schemaType, useDictionary, useIntegerFingerprints)`. Rows are grouped by
     * encoding, then by flag pair, then by [SchemaType] in declaration order.
     */
    @Parameterized.Parameters(name = "encoding: {0}, schemaType: {1}, useDictionary: {2}, useIntegerFingerprints: {3}")
    @JvmStatic
    fun data(): List<Array<Any?>> {
        // (useDictionary, useIntegerFingerprints) pairs used for the encodings that never run with a dictionary.
        val withoutDictionary = listOf(false to false, false to true)
        // ZSTANDARD is the only encoding run with a dictionary; it has no (false, true) row.
        val zstandardFlags = listOf(false to false, true to false, true to true)

        val flagsByEncoding: List<Pair<CordaSerializationEncoding?, List<Pair<Boolean, Boolean>>>> = listOf(
            null to withoutDictionary,
            CordaSerializationEncoding.DEFLATE to withoutDictionary,
            CordaSerializationEncoding.SNAPPY to withoutDictionary,
            CordaSerializationEncoding.ZSTANDARD to zstandardFlags
        )

        return flagsByEncoding.flatMap { (encoding, flagPairs) ->
            flagPairs.flatMap { (dictionary, integerFingerprints) ->
                SchemaType.values().map { schema ->
                    arrayOf<Any?>(encoding, schema, dictionary, integerFingerprints)
                }
            }
        }
    }
}
@Rule
@JvmField
@ -47,7 +91,7 @@ class CompatibilityTest {
assertNotNull(inputStream)
val inByteArray: ByteArray = inputStream.readBytes()
println("Original size = ${inByteArray.size}")
//println("Original size = ${inByteArray.size}")
val input = DeserializationInput(serializerFactory)
val (transaction, envelope) = input.deserializeAndReturnEnvelope(
@ -60,7 +104,27 @@ class CompatibilityTest {
assertEquals(1, commands.size)
assertTrue(commands.first().value is Cash.Commands.Issue)
val context = SerializationDefaults.STORAGE_CONTEXT.withExternalSchema(ExternalSchema()).withIntegerFingerprint()
val networkParams = javaClass.classLoader.getResourceAsStream("networkParams.r3corda.6a6b6f256").readBytes()
val trainer = ZstdDictTrainer(128 * 1024, 128 * 1024)
while (useDictionary) {
if (!trainer.addSample(inByteArray)) break
if (!trainer.addSample(networkParams)) break
}
val dict = if (useDictionary) trainer.trainSamples() else ByteArray(0)
val context = SerializationDefaults.STORAGE_CONTEXT.let {
if (schemaType != SchemaType.REPEAT) {
it.withExternalSchema(ExternalSchema())
} else {
it
}
}.let {
if (useIntegerFingerprints) {
it.withIntegerFingerprint()
} else {
it
}
}
//.withExternalSchema(ExternalSchema()).withIntegerFingerprint()
val newWtx = SerializationFactory.defaultFactory.asCurrent {
withCurrentContext(context) {
WireTransaction(transaction.tx.componentGroups.map { cg: ComponentGroup ->
@ -78,10 +142,27 @@ class CompatibilityTest {
// Serialize back and check that representation is byte-to-byte identical to what it was originally.
val output = SerializationOutput(serializerFactory)
val outByteArray = output.serialize(newTransaction, context.withExternalSchema(context.externalSchema!!.copy(flush = true))
.withEncoding(CordaSerializationEncoding.SNAPPY)).bytes
val outerContext = context.let {
if (encoding != null) {
it.withEncoding(encoding)
} else it
}.let {
if (schemaType != SchemaType.REPEAT) {
it.withExternalSchema(context.externalSchema!!.copy(flush = schemaType == SchemaType.SINGLE))
} else {
it
}
}.let {
if (useDictionary) {
it.withProperty(CordaSerializationEncoding.DICTIONARY_KEY, dict)
} else {
it
}
}
val outByteArray = output.serialize(newTransaction, outerContext /*context.withExternalSchema(context.externalSchema!!.copy(flush = true))
.withEncoding(CordaSerializationEncoding.ZSTANDARD).withProperty(CordaSerializationEncoding.DICTIONARY_KEY, dict)*/).bytes
//val (serializedBytes, schema) = output.serializeAndReturnSchema(transaction, SerializationDefaults.STORAGE_CONTEXT)
println("Output size = ${outByteArray.size}")
println("encoding: $encoding, schemaType: $schemaType, useDictionary: $useDictionary, useIntegerFingerprints: $useIntegerFingerprints, Output size = ${outByteArray.size}")
//assertSchemasMatch(envelope.schema, schema)

View File

@ -159,7 +159,7 @@ object KryoCheckpointSerializer : CheckpointSerializer {
context.encoding?.let { encoding ->
SectionId.ENCODING.writeTo(this)
(encoding as CordaSerializationEncoding).writeTo(this)
substitute(encoding::wrap)
substitute { outputStream -> encoding.wrap(outputStream, context.properties) }
}
SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case.
if (context.objectReferencesEnabled) {

View File

@ -1,6 +1,7 @@
import net.corda.gradle.jarfilter.JarFilterTask
import net.corda.gradle.jarfilter.MetaFixerTask
import proguard.gradle.ProGuardTask
import static org.gradle.api.JavaVersion.VERSION_1_8
plugins {
@ -43,6 +44,7 @@ dependencies {
// These "implementation" dependencies will become "runtime" scoped in our published POM.
implementation "org.iq80.snappy:snappy:$snappy_version"
implementation "com.github.luben:zstd-jni:1.5.2-5"
implementation "com.google.guava:guava:$guava_version"
}

View File

@ -30,6 +30,9 @@ dependencies {
// Pure-Java Snappy compression
compile "org.iq80.snappy:snappy:$snappy_version"
// JNI based Zstandard compression
compile "com.github.luben:zstd-jni:1.5.2-5"
// For caches rather than guava
compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version"

View File

@ -1,5 +1,7 @@
package net.corda.serialization.internal
import com.github.luben.zstd.ZstdInputStream
import com.github.luben.zstd.ZstdOutputStream
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializationEncoding
import net.corda.core.utilities.ByteSequence
@ -41,20 +43,33 @@ enum class SectionId : OrdinalWriter {
@KeepForDJVM
enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter {
DEFLATE {
override fun wrap(stream: OutputStream) = DeflaterOutputStream(stream)
override fun wrap(stream: OutputStream, context: Map<Any, Any>) = DeflaterOutputStream(stream)
override fun wrap(stream: InputStream) = InflaterInputStream(stream)
},
SNAPPY {
override fun wrap(stream: OutputStream) = FlushAverseOutputStream(SnappyFramedOutputStream(stream))
override fun wrap(stream: OutputStream, context: Map<Any, Any>) = FlushAverseOutputStream(SnappyFramedOutputStream(stream))
override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false)
},
ZSTANDARD {
    /**
     * Wraps [stream] in a [ZstdOutputStream].
     *
     * If [context] holds an `Int` under [LEVEL_KEY] it is passed to `setLevel`; if it holds a
     * `ByteArray` under [DICTIONARY_KEY] it is passed to `setDict`. Absent keys leave the stream as
     * constructed. A value of any other type fails the cast.
     */
    override fun wrap(stream: OutputStream, context: Map<Any, Any>) = ZstdOutputStream(stream).also { zstd ->
        (context[LEVEL_KEY] as Int?)?.let { level -> zstd.setLevel(level) }
        (context[DICTIONARY_KEY] as ByteArray?)?.let { dictionary -> zstd.setDict(dictionary) }
    }

    // NOTE(review): this overload receives no context, so a dictionary supplied via DICTIONARY_KEY on the
    // write side is not applied here - confirm how dictionary-compressed data is meant to be read back.
    override fun wrap(stream: InputStream) = ZstdInputStream(stream)
};
companion object {
    /** Reader built over all declared encodings, in `values()` order. */
    val reader = OrdinalReader(values())

    /** Context-property key under which the ZSTANDARD output wrapper looks for an `Int` level. */
    val LEVEL_KEY = "ZSTD_LEVEL_KEY"

    /** Context-property key under which the ZSTANDARD output wrapper looks for a `ByteArray` dictionary. */
    val DICTIONARY_KEY = "ZSTD_DICTIONARY"
}
override val bits = OrdinalBits(ordinal)
abstract fun wrap(stream: OutputStream): OutputStream
abstract fun wrap(stream: OutputStream, context: Map<Any, Any>): OutputStream
abstract fun wrap(stream: InputStream): InputStream
}

View File

@ -106,7 +106,7 @@ open class SerializationOutput constructor(
if (encoding != null) {
SectionId.ENCODING.writeTo(stream)
(encoding as CordaSerializationEncoding).writeTo(stream)
stream = encoding.wrap(stream)
stream = encoding.wrap(stream, context.properties)
}
SectionId.DATA_AND_STOP.writeTo(stream)
stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode)