diff --git a/finance/workflows/src/test/kotlin/net/corda/finance/flows/CompatibilityTest.kt b/finance/workflows/src/test/kotlin/net/corda/finance/flows/CompatibilityTest.kt index b2045d18c6..cfd429e9e7 100644 --- a/finance/workflows/src/test/kotlin/net/corda/finance/flows/CompatibilityTest.kt +++ b/finance/workflows/src/test/kotlin/net/corda/finance/flows/CompatibilityTest.kt @@ -1,5 +1,6 @@ package net.corda.finance.flows +import com.github.luben.zstd.ZstdDictTrainer import net.corda.core.serialization.ExternalSchema import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationFactory @@ -22,14 +23,57 @@ import net.corda.serialization.internal.amqp.custom.PublicKeySerializer import net.corda.testing.core.SerializationEnvironmentRule import org.junit.Rule import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue import kotlin.test.fail +enum class SchemaType { + REPEAT, + SINGLE, + NONE; +} + // TODO: If this type of testing gets momentum, we can create a mini-framework that rides through list of files // and performs necessary validation on all of them. -class CompatibilityTest { +@RunWith(Parameterized::class) +class CompatibilityTest(val encoding: CordaSerializationEncoding?, val schemaType: SchemaType, val useDictionary: Boolean, val useIntegerFingerprints: Boolean) { + + companion object { + @Parameterized.Parameters(name = "encoding: {0}, schemaType: {1}, useDictionary: {2}, useIntegerFingerprints: {3}") + @JvmStatic + fun data(): List> = listOf( + arrayOf(null, SchemaType.REPEAT, false, false), + arrayOf(null, SchemaType.SINGLE, false, false), + arrayOf(null, SchemaType.NONE, false, false), + arrayOf(null, SchemaType.REPEAT, false, true), + arrayOf(null, SchemaType.SINGLE, false, true), + arrayOf(null, SchemaType.NONE, false, true), + arrayOf(CordaSerializationEncoding.DEFLATE, SchemaType.REPEAT, false, false), + arrayOf(CordaSerializationEncoding.DEFLATE, SchemaType.SINGLE, false, false), + arrayOf(CordaSerializationEncoding.DEFLATE, SchemaType.NONE, false, false), + arrayOf(CordaSerializationEncoding.DEFLATE, SchemaType.REPEAT, false, true), + arrayOf(CordaSerializationEncoding.DEFLATE, SchemaType.SINGLE, false, true), + arrayOf(CordaSerializationEncoding.DEFLATE, SchemaType.NONE, false, true), + arrayOf(CordaSerializationEncoding.SNAPPY, SchemaType.REPEAT, false, false), + arrayOf(CordaSerializationEncoding.SNAPPY, SchemaType.SINGLE, false, false), + arrayOf(CordaSerializationEncoding.SNAPPY, SchemaType.NONE, false, false), + arrayOf(CordaSerializationEncoding.SNAPPY, SchemaType.REPEAT, false, true), + arrayOf(CordaSerializationEncoding.SNAPPY, SchemaType.SINGLE, false, true), + arrayOf(CordaSerializationEncoding.SNAPPY, SchemaType.NONE, false, true), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.REPEAT, false, false), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.SINGLE, false, false), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.NONE, false, false), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.REPEAT, true, false), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.SINGLE, true, false), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.NONE, true, false), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.REPEAT, true, true), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.SINGLE, true, true), + arrayOf(CordaSerializationEncoding.ZSTANDARD, SchemaType.NONE, true, true) + ) + } @Rule @JvmField @@ -47,7 +91,7 @@ class CompatibilityTest { assertNotNull(inputStream) val inByteArray: ByteArray = inputStream.readBytes() - println("Original size = ${inByteArray.size}") + //println("Original size = ${inByteArray.size}") val input = DeserializationInput(serializerFactory) val (transaction, envelope) = input.deserializeAndReturnEnvelope( @@ -60,7 +104,27 @@ class CompatibilityTest { assertEquals(1, commands.size) assertTrue(commands.first().value is Cash.Commands.Issue) - val context = SerializationDefaults.STORAGE_CONTEXT.withExternalSchema(ExternalSchema()).withIntegerFingerprint() + val networkParams = javaClass.classLoader.getResourceAsStream("networkParams.r3corda.6a6b6f256").readBytes() + val trainer = ZstdDictTrainer(128 * 1024, 128 * 1024) + while (useDictionary) { + if (!trainer.addSample(inByteArray)) break + if (!trainer.addSample(networkParams)) break + } + val dict = if (useDictionary) trainer.trainSamples() else ByteArray(0) + val context = SerializationDefaults.STORAGE_CONTEXT.let { + if (schemaType != SchemaType.REPEAT) { + it.withExternalSchema(ExternalSchema()) + } else { + it + } + }.let { + if (useIntegerFingerprints) { + it.withIntegerFingerprint() + } else { + it + } + } + //.withExternalSchema(ExternalSchema()).withIntegerFingerprint() val newWtx = SerializationFactory.defaultFactory.asCurrent { withCurrentContext(context) { WireTransaction(transaction.tx.componentGroups.map { cg: ComponentGroup -> @@ -78,10 +142,27 @@ class CompatibilityTest { // Serialize back and check that representation is byte-to-byte identical to what it was originally. val output = SerializationOutput(serializerFactory) - val outByteArray = output.serialize(newTransaction, context.withExternalSchema(context.externalSchema!!.copy(flush = true)) - .withEncoding(CordaSerializationEncoding.SNAPPY)).bytes + val outerContext = context.let { + if (encoding != null) { + it.withEncoding(encoding) + } else it + }.let { + if (schemaType != SchemaType.REPEAT) { + it.withExternalSchema(context.externalSchema!!.copy(flush = schemaType == SchemaType.SINGLE)) + } else { + it + } + }.let { + if (useDictionary) { + it.withProperty(CordaSerializationEncoding.DICTIONARY_KEY, dict) + } else { + it + } + } + val outByteArray = output.serialize(newTransaction, outerContext /*context.withExternalSchema(context.externalSchema!!.copy(flush = true)) + .withEncoding(CordaSerializationEncoding.ZSTANDARD).withProperty(CordaSerializationEncoding.DICTIONARY_KEY, dict)*/).bytes //val (serializedBytes, schema) = output.serializeAndReturnSchema(transaction, SerializationDefaults.STORAGE_CONTEXT) - println("Output size = ${outByteArray.size}") + println("encoding: $encoding, schemaType: $schemaType, useDictionary: $useDictionary, useIntegerFingerprints: $useIntegerFingerprints, Output size = ${outByteArray.size}") //assertSchemasMatch(envelope.schema, schema) diff --git a/finance/workflows/src/test/resources/networkParams.r3corda.6a6b6f256 b/finance/workflows/src/test/resources/networkParams.r3corda.6a6b6f256 new file mode 100644 index 0000000000..dcdbaa7b5f Binary files /dev/null and b/finance/workflows/src/test/resources/networkParams.r3corda.6a6b6f256 differ diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt index 178682e088..10e42a32bb 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt @@ -159,7 +159,7 @@ object KryoCheckpointSerializer : CheckpointSerializer { context.encoding?.let { encoding -> SectionId.ENCODING.writeTo(this) (encoding as CordaSerializationEncoding).writeTo(this) - substitute(encoding::wrap) + substitute { outputStream -> encoding.wrap(outputStream, context.properties) } } SectionId.ALT_DATA_AND_STOP.writeTo(this) // Forward-compatible in null-encoding case. if (context.objectReferencesEnabled) { diff --git a/serialization-deterministic/build.gradle b/serialization-deterministic/build.gradle index 7822eb3b23..bf0b759041 100644 --- a/serialization-deterministic/build.gradle +++ b/serialization-deterministic/build.gradle @@ -1,6 +1,7 @@ import net.corda.gradle.jarfilter.JarFilterTask import net.corda.gradle.jarfilter.MetaFixerTask import proguard.gradle.ProGuardTask + import static org.gradle.api.JavaVersion.VERSION_1_8 plugins { @@ -43,6 +44,7 @@ dependencies { // These "implementation" dependencies will become "runtime" scoped in our published POM. implementation "org.iq80.snappy:snappy:$snappy_version" + implementation "com.github.luben:zstd-jni:1.5.2-5" implementation "com.google.guava:guava:$guava_version" } diff --git a/serialization/build.gradle b/serialization/build.gradle index 224bd642b4..fbeea0a066 100644 --- a/serialization/build.gradle +++ b/serialization/build.gradle @@ -30,6 +30,9 @@ dependencies { // Pure-Java Snappy compression compile "org.iq80.snappy:snappy:$snappy_version" + // JNI based Zstandard compression + compile "com.github.luben:zstd-jni:1.5.2-5" + // For caches rather than guava compile "com.github.ben-manes.caffeine:caffeine:$caffeine_version" diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt index 7eb236f23d..bb2987d884 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationFormat.kt @@ -1,5 +1,7 @@ package net.corda.serialization.internal +import com.github.luben.zstd.ZstdInputStream +import com.github.luben.zstd.ZstdOutputStream import net.corda.core.KeepForDJVM import net.corda.core.serialization.SerializationEncoding import net.corda.core.utilities.ByteSequence @@ -41,20 +43,33 @@ enum class SectionId : OrdinalWriter { @KeepForDJVM enum class CordaSerializationEncoding : SerializationEncoding, OrdinalWriter { DEFLATE { - override fun wrap(stream: OutputStream) = DeflaterOutputStream(stream) + override fun wrap(stream: OutputStream, context: Map) = DeflaterOutputStream(stream) override fun wrap(stream: InputStream) = InflaterInputStream(stream) }, SNAPPY { - override fun wrap(stream: OutputStream) = FlushAverseOutputStream(SnappyFramedOutputStream(stream)) + override fun wrap(stream: OutputStream, context: Map) = FlushAverseOutputStream(SnappyFramedOutputStream(stream)) override fun wrap(stream: InputStream) = SnappyFramedInputStream(stream, false) + }, + ZSTANDARD { + override fun wrap(stream: OutputStream, context: Map) = ZstdOutputStream(stream).apply { + val contextLevel = context[LEVEL_KEY] as Int? + if (contextLevel != null) setLevel(contextLevel) + val contextDictionary = context[DICTIONARY_KEY] as ByteArray? + if (contextDictionary != null) setDict(contextDictionary) + } + + override fun wrap(stream: InputStream) = ZstdInputStream(stream) }; companion object { val reader = OrdinalReader(values()) + + val LEVEL_KEY: String = "ZSTD_LEVEL_KEY" + val DICTIONARY_KEY: String = "ZSTD_DICTIONARY" } override val bits = OrdinalBits(ordinal) - abstract fun wrap(stream: OutputStream): OutputStream + abstract fun wrap(stream: OutputStream, context: Map): OutputStream abstract fun wrap(stream: InputStream): InputStream } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt index e6786ccd52..9c02f2c490 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt @@ -106,7 +106,7 @@ open class SerializationOutput constructor( if (encoding != null) { SectionId.ENCODING.writeTo(stream) (encoding as CordaSerializationEncoding).writeTo(stream) - stream = encoding.wrap(stream) + stream = encoding.wrap(stream, context.properties) } SectionId.DATA_AND_STOP.writeTo(stream) stream.alsoAsByteBuffer(data.encodedSize().toInt(), data::encode)