From c33720c73d0fd40b8d98bcbb63aafb0d7ed8bd5b Mon Sep 17 00:00:00 2001 From: Joseph Zuniga-Daly <59851625+josephzunigadaly@users.noreply.github.com> Date: Wed, 22 Jul 2020 17:31:59 +0100 Subject: [PATCH] CORDA-3717: Apply custom serializers to checkpoints (#6392) * CORDA-3717: Apply custom serializers to checkpoints * Remove try/catch to fix TooGenericExceptionCaught detekt rule * Rename exception * Extract method * Put calls to the userSerializer on their own lines to improve readability * Remove unused constructors from exception * Remove unused proxyType field * Give field a descriptive name * Explain why we are looking for two type parameters when we only use one * Tidy up the fetching of types * Use 0 seconds when forcing a flow checkpoint inside test * Add test to check references are restored correctly * Add CheckpointCustomSerializer interface * Wire up the new CheckpointCustomSerializer interface * Use kryo default for abstract classes * Remove unused imports * Remove need for external library in tests * Make file match original to remove from diff * Remove maySkipCheckpoint from calls to sleep * Add newline to end of file * Test custom serializers mapped to interfaces * Test serializer configured with abstract class * Move test into its own package * Rename test * Move flows and serializers into their own source file * Move broken map into its own source file * Delete comment now source file is simpler * Rename class to have a shorter name * Add tests that run the checkpoint serializer directly * Check serialization of final classes * Register as default unless the target class is final * Test PublicKey serializer has not been overridden * Add a broken serializer for EdDSAPublicKey to make test more robust * Split serializer registration into default and non-default registrations. Run registrations at the right time to preserve Cordas own custom serializers. * Check for duplicate custom checkpoint serializers * Add doc comments * Add doc comments to CustomSerializerCheckpointAdaptor * Add test to check duplicate serializers are logged * Do not log the duplicate serializer warning when the duplicate is the same class * Update doc comment for CheckpointCustomSerializer * Sort serializers by classname so we are not registering in an unknown or random order * Add test to serialize a class that references itself * Store custom serializer type in the Kryo stream so we can spot when a different serializer is being used to deserialize * Testing has shown that registering custom serializers as default is more robust when adding new cordapps * Remove new line character * Remove unused imports * Add interface net.corda.core.serialization.CheckpointCustomSerializer to api-current.txt * Remove comment * Update comment on exception * Make CustomSerializerCheckpointAdaptor internal * Revert "Add interface net.corda.core.serialization.CheckpointCustomSerializer to api-current.txt" This reverts commit b835de79bd21f0048be741e7fc5f0c3088516d2b. * Restore "Add interface net.corda.core.serialization.CheckpointCustomSerializer to api-current.txt"" This reverts commit 718873a4e963bad4e327bb200e7bb4de44bc47ad. * Pass the class loader instead of the context * Do less work in test setup * Make the serialization context unique for CustomCheckpointSerializerTest so we get a new Kryo pool for the test * Rebuild the Kryo pool for the given context when we change custom serializers * Rebuild all Kryo pools on serializer change to keep serializer list consistent * Move the custom serializer list into CheckpointSerializationContext to reduce scope from global to a serialization context * Remove unused imports * Make the new checkpointCustomSerializers property default to the empty list * Delegate implementation using kotlin language feature --- .ci/api-current.txt | 4 + .../kotlin/net/corda/core/cordapp/Cordapp.kt | 3 + .../core/internal/cordapp/CordappImpl.kt | 3 + .../SerializationCustomSerializer.kt | 23 ++ .../internal/CheckpointSerializationAPI.kt | 9 + .../kryo/CustomSerializerCheckpointAdaptor.kt | 103 +++++++++ .../kryo/KryoCheckpointSerializer.kt | 54 ++++- .../CustomCheckpointSerializerTest.kt | 99 ++++++++ .../DifficultToSerialize.kt | 27 +++ .../DuplicateSerializerLogTest.kt | 59 +++++ ...cateSerializerLogWithSameSerializerTest.kt | 58 +++++ ...ckNetworkCustomCheckpointSerializerTest.kt | 75 ++++++ .../ReferenceLoopTest.kt | 75 ++++++ .../customcheckpointserializer/TestCorDapp.kt | 214 ++++++++++++++++++ .../kotlin/net/corda/node/internal/Node.kt | 4 +- .../cordapp/JarScanningCordappLoader.kt | 6 + .../node/internal/cordapp/VirtualCordapps.kt | 4 + .../internal/CheckpointSerializationScheme.kt | 6 +- .../InternalSerializationTestHelpers.kt | 6 +- 19 files changed, 826 insertions(+), 6 deletions(-) create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/CustomCheckpointSerializerTest.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DifficultToSerialize.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogTest.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogWithSameSerializerTest.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/MockNetworkCustomCheckpointSerializerTest.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/ReferenceLoopTest.kt create mode 100644 node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/TestCorDapp.kt diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 64e351610e..10374f09e3 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -5398,6 +5398,10 @@ public interface net.corda.core.schemas.QueryableState extends net.corda.core.co ## public interface net.corda.core.schemas.StatePersistable ## +public interface net.corda.core.serialization.CheckpointCustomSerializer + public abstract OBJ fromProxy(PROXY) + public abstract PROXY toProxy(OBJ) +## public interface net.corda.core.serialization.ClassWhitelist public abstract boolean hasListed(Class) ## diff --git a/core/src/main/kotlin/net/corda/core/cordapp/Cordapp.kt b/core/src/main/kotlin/net/corda/core/cordapp/Cordapp.kt index 753e842fe6..1dd153e0ae 100644 --- a/core/src/main/kotlin/net/corda/core/cordapp/Cordapp.kt +++ b/core/src/main/kotlin/net/corda/core/cordapp/Cordapp.kt @@ -7,6 +7,7 @@ import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.internal.cordapp.CordappImpl.Companion.UNKNOWN_VALUE import net.corda.core.schemas.MappedSchema +import net.corda.core.serialization.CheckpointCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializeAsToken @@ -29,6 +30,7 @@ import java.net.URL * @property services List of RPC services * @property serializationWhitelists List of Corda plugin registries * @property serializationCustomSerializers List of serializers + * @property checkpointCustomSerializers List of serializers for checkpoints * @property customSchemas List of custom schemas * @property allFlows List of all flow classes * @property jarPath The path to the JAR for this CorDapp @@ -49,6 +51,7 @@ interface Cordapp { val services: List> val serializationWhitelists: List val serializationCustomSerializers: List> + val checkpointCustomSerializers: List> val customSchemas: Set val allFlows: List>> val jarPath: URL diff --git a/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt b/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt index d511ba7860..1c5d69e511 100644 --- a/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt +++ b/core/src/main/kotlin/net/corda/core/internal/cordapp/CordappImpl.kt @@ -9,6 +9,7 @@ import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.notary.NotaryService import net.corda.core.internal.toPath import net.corda.core.schemas.MappedSchema +import net.corda.core.serialization.CheckpointCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializeAsToken @@ -25,6 +26,7 @@ data class CordappImpl( override val services: List>, override val serializationWhitelists: List, override val serializationCustomSerializers: List>, + override val checkpointCustomSerializers: List>, override val customSchemas: Set, override val allFlows: List>>, override val jarPath: URL, @@ -79,6 +81,7 @@ data class CordappImpl( services = emptyList(), serializationWhitelists = emptyList(), serializationCustomSerializers = emptyList(), + checkpointCustomSerializers = emptyList(), customSchemas = emptySet(), jarPath = Paths.get("").toUri().toURL(), info = UNKNOWN_INFO, diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomSerializer.kt index d0c910b638..ed387f8f94 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomSerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomSerializer.kt @@ -25,3 +25,26 @@ interface SerializationCustomSerializer { */ fun fromProxy(proxy: PROXY): OBJ } + +/** + * Allows CorDapps to provide custom serializers for classes that do not serialize successfully during a checkpoint. + * In this case, a proxy serializer can be written that implements this interface whose purpose is to move between + * unserializable types and an intermediate representation. + * + * NOTE: Only implement this interface if you have a class that triggers an error during normal checkpoint + * serialization/deserialization. + */ +@KeepForDJVM +interface CheckpointCustomSerializer { + /** + * Should facilitate the conversion of the third party object into the serializable + * local class specified by [PROXY] + */ + fun toProxy(obj: OBJ): PROXY + + /** + * Should facilitate the conversion of the proxy object into a new instance of the + * unserializable type + */ + fun fromProxy(proxy: PROXY): OBJ +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt index 98fdcd730d..510986141c 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt @@ -56,6 +56,10 @@ interface CheckpointSerializationContext { * otherwise they appear as new copies of the object. */ val objectReferencesEnabled: Boolean + /** + * User defined custom serializers for use in checkpoint serialization. + */ + val checkpointCustomSerializers: Iterable> /** * Helper method to return a new context based on this context with the property added. @@ -86,6 +90,11 @@ interface CheckpointSerializationContext { * A shallow copy of this context but with the given encoding whitelist. */ fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext + + /** + * A shallow copy of this context but with the given custom serializers. + */ + fun withCheckpointCustomSerializers(checkpointCustomSerializers: Iterable>): CheckpointSerializationContext } /* diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt new file mode 100644 index 0000000000..4f3475696b --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt @@ -0,0 +1,103 @@ +package net.corda.nodeapi.internal.serialization.kryo + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.serialization.internal.amqp.CORDAPP_TYPE +import java.lang.reflect.Type +import kotlin.reflect.jvm.javaType +import kotlin.reflect.jvm.jvmErasure + +/** + * Adapts CheckpointCustomSerializer for use in Kryo + */ +internal class CustomSerializerCheckpointAdaptor(private val userSerializer : CheckpointCustomSerializer) : Serializer() { + + /** + * The class name of the serializer we are adapting. + */ + val serializerName: String = userSerializer.javaClass.name + + /** + * The input type of this custom serializer. + */ + val cordappType: Type + + /** + * Check we have access to the types specified on the CheckpointCustomSerializer interface. + * + * Throws UnableToDetermineSerializerTypesException if the types are missing. + */ + init { + val types: List = userSerializer::class + .supertypes + .filter { it.jvmErasure == CheckpointCustomSerializer::class } + .flatMap { it.arguments } + .mapNotNull { it.type?.javaType } + + // We are expecting a cordapp type and a proxy type. + // We will only use the cordapp type in this class + // but we want to check both are present. + val typeParameterCount = 2 + if (types.size != typeParameterCount) { + throw UnableToDetermineSerializerTypesException("Unable to determine serializer parent types") + } + cordappType = types[CORDAPP_TYPE] + } + + /** + * Serialize obj to the Kryo stream. + */ + override fun write(kryo: Kryo, output: Output, obj: OBJ) { + + fun writeToKryo(obj: T) = kryo.writeClassAndObject(output, obj) + + // Write serializer type + writeToKryo(serializerName) + + // Write proxy object + writeToKryo(userSerializer.toProxy(obj)) + } + + /** + * Deserialize an object from the Kryo stream. + */ + override fun read(kryo: Kryo, input: Input, type: Class): OBJ { + + @Suppress("UNCHECKED_CAST") + fun readFromKryo() = kryo.readClassAndObject(input) as T + + // Check the serializer type + checkSerializerType(readFromKryo()) + + // Read the proxy object + return userSerializer.fromProxy(readFromKryo()) + } + + /** + * Throws a `CustomCheckpointSerializersHaveChangedException` if the serializer type in the kryo stream does not match the serializer + * type for this custom serializer. + * + * @param checkpointSerializerType Serializer type from the Kryo stream + */ + private fun checkSerializerType(checkpointSerializerType: String) { + if (checkpointSerializerType != serializerName) + throw CustomCheckpointSerializersHaveChangedException("The custom checkpoint serializers have changed while checkpoints exist. " + + "Please restore the CorDapps to when this checkpoint was created.") + } +} + +/** + * Thrown when the input/output types are missing from the custom serializer. + */ +class UnableToDetermineSerializerTypesException(message: String) : RuntimeException(message) + +/** + * Thrown when the custom serializer is found to be reading data from another type of custom serializer. + * + * This was expected to happen if the user adds or removes CorDapps while checkpoints exist but it turned out that registering serializers + * as default made the system reliable. + */ +class CustomCheckpointSerializersHaveChangedException(message: String) : RuntimeException(message) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt index 6a73119ce6..06698d99ad 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt @@ -10,12 +10,14 @@ import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.serializers.ClosureSerializer import net.corda.core.internal.uncheckedCast +import net.corda.core.serialization.CheckpointCustomSerializer import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializer import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.loggerFor import net.corda.serialization.internal.AlwaysAcceptEncodingWhitelist import net.corda.serialization.internal.ByteBufferInputStream import net.corda.serialization.internal.CheckpointSerializationContextImpl @@ -40,10 +42,10 @@ private object AutoCloseableSerialisationDetector : Serializer() } object KryoCheckpointSerializer : CheckpointSerializer { - private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() + private val kryoPoolsForContexts = ConcurrentHashMap>>, KryoPool>() private fun getPool(context: CheckpointSerializationContext): KryoPool { - return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { + return kryoPoolsForContexts.computeIfAbsent(Triple(context.whitelist, context.deserializationClassLoader, context.checkpointCustomSerializers)) { KryoPool.Builder { val serializer = Fiber.getFiberSerializer(false) as KryoSerializer val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) } @@ -56,12 +58,60 @@ object KryoCheckpointSerializer : CheckpointSerializer { addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) register(ClosureSerializer.Closure::class.java, CordaClosureSerializer) classLoader = it.second + + // Add custom serializers + val customSerializers = buildCustomSerializerAdaptors(context) + warnAboutDuplicateSerializers(customSerializers) + val classToSerializer = mapInputClassToCustomSerializer(context.deserializationClassLoader, customSerializers) + addDefaultCustomSerializers(this, classToSerializer) } }.build() } } + /** + * Returns a sorted list of CustomSerializerCheckpointAdaptor based on the custom serializers inside context. + * + * The adaptors are sorted by serializerName which maps to javaClass.name for the serializer class + */ + private fun buildCustomSerializerAdaptors(context: CheckpointSerializationContext) = + context.checkpointCustomSerializers.map { CustomSerializerCheckpointAdaptor(it) }.sortedBy { it.serializerName } + + /** + * Returns a list of pairs where the first element is the input class of the custom serializer and the second element is the + * custom serializer. + */ + private fun mapInputClassToCustomSerializer(classLoader: ClassLoader, customSerializers: Iterable>) = + customSerializers.map { getInputClassForCustomSerializer(classLoader, it) to it } + + /** + * Returns the Class object for the serializers input type. + */ + private fun getInputClassForCustomSerializer(classLoader: ClassLoader, customSerializer: CustomSerializerCheckpointAdaptor<*, *>): Class<*> { + val typeNameWithoutGenerics = customSerializer.cordappType.typeName.substringBefore('<') + return classLoader.loadClass(typeNameWithoutGenerics) + } + + /** + * Emit a warning if two or more custom serializers are found for the same input type. + */ + private fun warnAboutDuplicateSerializers(customSerializers: Iterable>) = + customSerializers + .groupBy({ it.cordappType }, { it.serializerName }) + .filter { (_, serializerNames) -> serializerNames.distinct().size > 1 } + .forEach { (inputType, serializerNames) -> loggerFor().warn("Duplicate custom checkpoint serializer for type $inputType. Serializers: ${serializerNames.joinToString(", ")}") } + + /** + * Register all custom serializers as default, this class + subclass, registrations. + * + * Serializers registered before this will take priority. This needs to run after registrations we want to keep otherwise it may + * replace them. + */ + private fun addDefaultCustomSerializers(kryo: Kryo, classToSerializer: Iterable, CustomSerializerCheckpointAdaptor<*, *>>>) = + classToSerializer + .forEach { (clazz, customSerializer) -> kryo.addDefaultSerializer(clazz, customSerializer) } + private fun CheckpointSerializationContext.kryo(task: Kryo.() -> T): T { return getPool(this).run { kryo -> kryo.context.ensureCapacity(properties.size) diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/CustomCheckpointSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/CustomCheckpointSerializerTest.kt new file mode 100644 index 0000000000..0efb030fff --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/CustomCheckpointSerializerTest.kt @@ -0,0 +1,99 @@ +package net.corda.node.customcheckpointserializer + +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.crypto.generateKeyPair +import net.corda.core.serialization.EncodingWhitelist +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.coretesting.internal.rigorousMock +import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.CheckpointSerializationContextImpl +import net.corda.serialization.internal.CordaSerializationEncoding +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule +import org.junit.Assert +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(Parameterized::class) +class CustomCheckpointSerializerTest(private val compression: CordaSerializationEncoding?) { + companion object { + @Parameterized.Parameters(name = "{0}") + @JvmStatic + fun compression() = arrayOf(null) + CordaSerializationEncoding.values() + } + + @get:Rule + val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true) + private val context: CheckpointSerializationContext = CheckpointSerializationContextImpl( + deserializationClassLoader = javaClass.classLoader, + whitelist = AllWhitelist, + properties = emptyMap(), + objectReferencesEnabled = true, + encoding = compression, + encodingWhitelist = rigorousMock().also { + if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression) + }, + checkpointCustomSerializers = listOf( + TestCorDapp.TestAbstractClassSerializer(), + TestCorDapp.TestClassSerializer(), + TestCorDapp.TestInterfaceSerializer(), + TestCorDapp.TestFinalClassSerializer(), + TestCorDapp.BrokenPublicKeySerializer() + ) + ) + + @Test(timeout=300_000) + fun `test custom checkpoint serialization`() { + testBrokenMapSerialization(DifficultToSerialize.BrokenMapClass()) + } + + @Test(timeout=300_000) + fun `test custom checkpoint serialization using interface`() { + testBrokenMapSerialization(DifficultToSerialize.BrokenMapInterfaceImpl()) + } + + @Test(timeout=300_000) + fun `test custom checkpoint serialization using abstract class`() { + testBrokenMapSerialization(DifficultToSerialize.BrokenMapAbstractImpl()) + } + + @Test(timeout=300_000) + fun `test custom checkpoint serialization using final class`() { + testBrokenMapSerialization(DifficultToSerialize.BrokenMapFinal()) + } + + @Test(timeout=300_000) + fun `test PublicKey serializer has not been overridden`() { + + val publicKey = generateKeyPair().public + + // Serialize/deserialize + val checkpoint = publicKey.checkpointSerialize(context) + val deserializedCheckpoint = checkpoint.checkpointDeserialize(context) + + // Check the elements are as expected + Assert.assertArrayEquals(publicKey.encoded, deserializedCheckpoint.encoded) + } + + + private fun testBrokenMapSerialization(brokenMap : MutableMap): MutableMap { + // Add elements to the map + brokenMap.putAll(mapOf("key" to "value")) + + // Serialize/deserialize + val checkpoint = brokenMap.checkpointSerialize(context) + val deserializedCheckpoint = checkpoint.checkpointDeserialize(context) + + // Check the elements are as expected + Assert.assertEquals(1, deserializedCheckpoint.size) + Assert.assertEquals("value", deserializedCheckpoint.get("key")) + + // Return map for extra checks + return deserializedCheckpoint + } +} + diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DifficultToSerialize.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DifficultToSerialize.kt new file mode 100644 index 0000000000..f272e71ebf --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DifficultToSerialize.kt @@ -0,0 +1,27 @@ +package net.corda.node.customcheckpointserializer + +import net.corda.core.flows.FlowException + +class DifficultToSerialize { + + // Broken Map + // This map breaks the rules for the put method. Making the normal map serializer fail. + + open class BrokenMapBaseImpl(delegate: MutableMap = mutableMapOf()) : MutableMap by delegate { + override fun put(key: K, value: V): V? = throw FlowException("Broken on purpose") + } + + // A class to test custom serializers applied to implementations + class BrokenMapClass : BrokenMapBaseImpl() + + // An interface and implementation to test custom serializers applied to interface types + interface BrokenMapInterface : MutableMap + class BrokenMapInterfaceImpl : BrokenMapBaseImpl(), BrokenMapInterface + + // An abstract class and implementation to test custom serializers applied to interface types + abstract class BrokenMapAbstract : BrokenMapBaseImpl(), MutableMap + class BrokenMapAbstractImpl : BrokenMapAbstract() + + // A final class + final class BrokenMapFinal: BrokenMapBaseImpl() +} diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogTest.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogTest.kt new file mode 100644 index 0000000000..2f87e1005f --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogTest.kt @@ -0,0 +1,59 @@ +package net.corda.node.customcheckpointserializer + +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.messaging.startFlow +import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.core.utilities.getOrThrow +import net.corda.node.logging.logFile +import net.corda.testing.driver.driver +import org.assertj.core.api.Assertions +import org.junit.Test +import java.time.Duration + +class DuplicateSerializerLogTest{ + @Test(timeout=300_000) + fun `check duplicate serialisers are logged`() { + driver { + val node = startNode(startInSameProcess = false).getOrThrow() + node.rpc.startFlow(::TestFlow).returnValue.get() + + val text = node.logFile().readLines().filter { it.startsWith("[WARN") } + + // Initial message is correct + Assertions.assertThat(text).anyMatch {it.contains("Duplicate custom checkpoint serializer for type net.corda.node.customcheckpointserializer.DifficultToSerialize\$BrokenMapInterface. Serializers: ")} + // Message mentions TestInterfaceSerializer + Assertions.assertThat(text).anyMatch {it.contains("net.corda.node.customcheckpointserializer.TestCorDapp\$TestInterfaceSerializer")} + // Message mentions DuplicateSerializer + Assertions.assertThat(text).anyMatch {it.contains("net.corda.node.customcheckpointserializer.DuplicateSerializerLogTest\$DuplicateSerializer")} + } + } + + @StartableByRPC + @InitiatingFlow + class TestFlow : FlowLogic>() { + override fun call(): DifficultToSerialize.BrokenMapInterface { + val brokenMap: DifficultToSerialize.BrokenMapInterface = DifficultToSerialize.BrokenMapInterfaceImpl() + brokenMap.putAll(mapOf("test" to "input")) + + sleep(Duration.ofSeconds(0)) + + return brokenMap + } + } + + @Suppress("unused") + class DuplicateSerializer : + CheckpointCustomSerializer, HashMap> { + + override fun toProxy(obj: DifficultToSerialize.BrokenMapInterface): HashMap { + val proxy = HashMap() + return obj.toMap(proxy) + } + override fun fromProxy(proxy: HashMap): DifficultToSerialize.BrokenMapInterface { + return DifficultToSerialize.BrokenMapInterfaceImpl() + .also { it.putAll(proxy) } + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogWithSameSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogWithSameSerializerTest.kt new file mode 100644 index 0000000000..598b1ed401 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/DuplicateSerializerLogWithSameSerializerTest.kt @@ -0,0 +1,58 @@ +package net.corda.node.customcheckpointserializer + +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.messaging.startFlow +import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.getOrThrow +import net.corda.node.logging.logFile +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.internal.enclosedCordapp +import org.assertj.core.api.Assertions +import org.junit.Test +import java.time.Duration + +class DuplicateSerializerLogWithSameSerializerTest { + @Test(timeout=300_000) + fun `check duplicate serialisers are logged not logged for the same class`() { + + // Duplicate the cordapp in this node + driver(DriverParameters(cordappsForAllNodes = listOf(this.enclosedCordapp(), this.enclosedCordapp()))) { + val node = startNode(startInSameProcess = false).getOrThrow() + node.rpc.startFlow(::TestFlow).returnValue.get() + + val text = node.logFile().readLines().filter { it.startsWith("[WARN") } + + // Initial message is not logged + Assertions.assertThat(text) + .anyMatch { !it.contains("Duplicate custom checkpoint serializer for type ") } + // Log does not mention DuplicateSerializerThatShouldNotBeLogged + Assertions.assertThat(text) + .anyMatch { !it.contains("DuplicateSerializerThatShouldNotBeLogged") } + } + } + + @CordaSerializable + class UnusedClass + + @Suppress("unused") + class DuplicateSerializerThatShouldNotBeLogged : CheckpointCustomSerializer { + override fun toProxy(obj: UnusedClass): String = "" + override fun fromProxy(proxy: String): UnusedClass = UnusedClass() + } + + @StartableByRPC + @InitiatingFlow + class TestFlow : FlowLogic() { + override fun call(): UnusedClass { + val unusedClass = UnusedClass() + + sleep(Duration.ofSeconds(0)) + + return unusedClass + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/MockNetworkCustomCheckpointSerializerTest.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/MockNetworkCustomCheckpointSerializerTest.kt new file mode 100644 index 0000000000..5bd60293c4 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/MockNetworkCustomCheckpointSerializerTest.kt @@ -0,0 +1,75 @@ +package net.corda.node.customcheckpointserializer + +import co.paralleluniverse.fibers.Suspendable +import net.corda.testing.node.MockNetwork +import net.corda.testing.node.MockNetworkParameters +import org.assertj.core.api.Assertions +import org.junit.After +import org.junit.Before +import org.junit.Test + +class MockNetworkCustomCheckpointSerializerTest { + private lateinit var mockNetwork: MockNetwork + + @Before + fun setup() { + mockNetwork = MockNetwork(MockNetworkParameters(cordappsForAllNodes = listOf(TestCorDapp.getCorDapp()))) + } + + @After + fun shutdown() { + mockNetwork.stopNodes() + } + + @Test(timeout = 300_000) + fun `flow suspend with custom kryo serializer`() { + val node = mockNetwork.createPartyNode() + val expected = 5 + val actual = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariable(5)).get() + + Assertions.assertThat(actual).isEqualTo(expected) + } + + @Test(timeout = 300_000) + fun `check references are restored correctly`() { + val node = mockNetwork.createPartyNode() + val expectedReference = DifficultToSerialize.BrokenMapClass() + expectedReference.putAll(mapOf("one" to 1)) + val actualReference = node.startFlow(TestCorDapp.TestFlowCheckingReferencesWork(expectedReference)).get() + + Assertions.assertThat(actualReference).isSameAs(expectedReference) + Assertions.assertThat(actualReference["one"]).isEqualTo(1) + } + + @Test(timeout = 300_000) + @Suspendable + fun `check serialization of interfaces`() { + val node = mockNetwork.createPartyNode() + val result = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariableAsInterface(5)).get() + Assertions.assertThat(result).isEqualTo(5) + } + + @Test(timeout = 300_000) + @Suspendable + fun `check serialization of abstract classes`() { + val node = mockNetwork.createPartyNode() + val result = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariableAsAbstract(5)).get() + Assertions.assertThat(result).isEqualTo(5) + } + + @Test(timeout = 300_000) + @Suspendable + fun `check serialization of final classes`() { + val node = mockNetwork.createPartyNode() + val result = node.startFlow(TestCorDapp.TestFlowWithDifficultToSerializeLocalVariableAsFinal(5)).get() + Assertions.assertThat(result).isEqualTo(5) + } + + @Test(timeout = 300_000) + @Suspendable + fun `check PublicKey serializer has not been overridden`() { + val node = mockNetwork.createPartyNode() + val result = node.startFlow(TestCorDapp.TestFlowCheckingPublicKeySerializer()).get() + Assertions.assertThat(result.encoded).isEqualTo(node.info.legalIdentities.first().owningKey.encoded) + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/ReferenceLoopTest.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/ReferenceLoopTest.kt new file mode 100644 index 0000000000..92a8d396c4 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/ReferenceLoopTest.kt @@ -0,0 +1,75 @@ +package net.corda.node.customcheckpointserializer + +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.core.serialization.EncodingWhitelist +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.coretesting.internal.rigorousMock +import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.CheckpointSerializationContextImpl +import net.corda.serialization.internal.CordaSerializationEncoding +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule +import org.junit.Assert +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized + +@RunWith(Parameterized::class) +class ReferenceLoopTest(private val compression: CordaSerializationEncoding?) { + companion object { + @Parameterized.Parameters(name = "{0}") + @JvmStatic + fun compression() = arrayOf(null) + CordaSerializationEncoding.values() + } + + @get:Rule + val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true) + private val context: CheckpointSerializationContext = CheckpointSerializationContextImpl( + deserializationClassLoader = javaClass.classLoader, + whitelist = AllWhitelist, + properties = emptyMap(), + objectReferencesEnabled = true, + encoding = compression, + encodingWhitelist = rigorousMock() + .also { + if (compression != null) doReturn(true).whenever(it) + .acceptEncoding(compression) + }, + checkpointCustomSerializers = listOf(PersonSerializer())) + + @Test(timeout=300_000) + fun `custom checkpoint serialization with reference loop`() { + val person = Person("Test name") + + val result = person.checkpointSerialize(context).checkpointDeserialize(context) + + Assert.assertEquals("Test name", result.name) + Assert.assertEquals("Test name", result.bestFriend.name) + Assert.assertSame(result, result.bestFriend) + } + + /** + * Test class that will hold a reference to itself + */ + class Person(val name: String, bestFriend: Person? = null) { + val bestFriend: Person = bestFriend ?: this + } + + /** + * Custom serializer for the Person class + */ + @Suppress("unused") + class PersonSerializer : CheckpointCustomSerializer> { + override fun toProxy(obj: Person): Map { + return mapOf("name" to obj.name, "bestFriend" to obj.bestFriend) + } + + override fun fromProxy(proxy: Map): Person { + return Person(proxy["name"] as String, proxy["bestFriend"] as Person?) + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/TestCorDapp.kt b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/TestCorDapp.kt new file mode 100644 index 0000000000..1d3e929dde --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/customcheckpointserializer/TestCorDapp.kt @@ -0,0 +1,214 @@ +package net.corda.node.customcheckpointserializer + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.testing.node.internal.CustomCordapp +import net.corda.testing.node.internal.enclosedCordapp +import net.i2p.crypto.eddsa.EdDSAPublicKey +import org.assertj.core.api.Assertions +import java.security.PublicKey +import java.time.Duration + +/** + * Contains all the flows and custom serializers for testing custom checkpoint serializers + */ +class TestCorDapp { + + companion object { + fun getCorDapp(): CustomCordapp = enclosedCordapp() + } + + // Flows + @StartableByRPC + class TestFlowWithDifficultToSerializeLocalVariableAsAbstract(private val purchase: Int) : FlowLogic() { + @Suspendable + override fun call(): Int { + + // This object is difficult to serialize with Kryo + val difficultToSerialize: DifficultToSerialize.BrokenMapAbstract = DifficultToSerialize.BrokenMapAbstractImpl() + difficultToSerialize.putAll(mapOf("foo" to purchase)) + + // Force a checkpoint + sleep(Duration.ofSeconds(0)) + + // Return value from deserialized object + return difficultToSerialize["foo"] ?: 0 + } + } + + @StartableByRPC + class TestFlowWithDifficultToSerializeLocalVariableAsFinal(private val purchase: Int) : FlowLogic() { + @Suspendable + override fun call(): Int { + + // This object is difficult to serialize with Kryo + val difficultToSerialize: DifficultToSerialize.BrokenMapFinal = DifficultToSerialize.BrokenMapFinal() + difficultToSerialize.putAll(mapOf("foo" to purchase)) + + // Force a checkpoint + sleep(Duration.ofSeconds(0)) + + // Return value from deserialized object + return difficultToSerialize["foo"] ?: 0 + } + } + + @StartableByRPC + class TestFlowWithDifficultToSerializeLocalVariableAsInterface(private val purchase: Int) : FlowLogic() { + @Suspendable + override fun call(): Int { + + // This object is difficult to serialize with Kryo + val difficultToSerialize: DifficultToSerialize.BrokenMapInterface = DifficultToSerialize.BrokenMapInterfaceImpl() + difficultToSerialize.putAll(mapOf("foo" to purchase)) + + // Force a checkpoint + sleep(Duration.ofSeconds(0)) + + // Return value from deserialized object + return difficultToSerialize["foo"] ?: 0 + } + } + + @StartableByRPC + class TestFlowWithDifficultToSerializeLocalVariable(private val purchase: Int) : FlowLogic() { + @Suspendable + override fun call(): Int { + + // This object is difficult to serialize with Kryo + val difficultToSerialize: DifficultToSerialize.BrokenMapClass = DifficultToSerialize.BrokenMapClass() + difficultToSerialize.putAll(mapOf("foo" to purchase)) + + // Force a checkpoint + sleep(Duration.ofSeconds(0)) + + // Return value from deserialized object + return difficultToSerialize["foo"] ?: 0 + } + } + + @StartableByRPC + class TestFlowCheckingReferencesWork(private val reference: DifficultToSerialize.BrokenMapClass) : + FlowLogic>() { + + private val referenceField = reference + @Suspendable + override fun call(): DifficultToSerialize.BrokenMapClass { + + val ref = referenceField + + // Force a checkpoint + sleep(Duration.ofSeconds(0)) + + // Check all objects refer to same object + Assertions.assertThat(reference).isSameAs(referenceField) + Assertions.assertThat(referenceField).isSameAs(ref) + + // Return deserialized object + return ref + } + } + + + @StartableByRPC + class TestFlowCheckingPublicKeySerializer : + FlowLogic() { + + @Suspendable + override fun call(): PublicKey { + val ref = ourIdentity.owningKey + + // Force a checkpoint + sleep(Duration.ofSeconds(0)) + + // Return deserialized object + return ref + } + } + + // Custom serializers + + @Suppress("unused") + class TestInterfaceSerializer : + CheckpointCustomSerializer, HashMap> { + + override fun toProxy(obj: DifficultToSerialize.BrokenMapInterface): HashMap { + val proxy = HashMap() + return obj.toMap(proxy) + } + override fun fromProxy(proxy: HashMap): DifficultToSerialize.BrokenMapInterface { + return DifficultToSerialize.BrokenMapInterfaceImpl() + .also { it.putAll(proxy) } + } + } + + @Suppress("unused") + class TestClassSerializer : + CheckpointCustomSerializer, HashMap> { + + override fun toProxy(obj: DifficultToSerialize.BrokenMapClass): HashMap { + val proxy = HashMap() + return obj.toMap(proxy) + } + override fun fromProxy(proxy: HashMap): DifficultToSerialize.BrokenMapClass { + return DifficultToSerialize.BrokenMapClass() + .also { it.putAll(proxy) } + } + } + + @Suppress("unused") + class TestAbstractClassSerializer : + CheckpointCustomSerializer, HashMap> { + + override fun toProxy(obj: DifficultToSerialize.BrokenMapAbstract): HashMap { + val proxy = HashMap() + return obj.toMap(proxy) + } + override fun fromProxy(proxy: HashMap): DifficultToSerialize.BrokenMapAbstract { + return DifficultToSerialize.BrokenMapAbstractImpl() + .also { it.putAll(proxy) } + } + } + + @Suppress("unused") + class TestFinalClassSerializer : + CheckpointCustomSerializer, HashMap> { + + override fun toProxy(obj: DifficultToSerialize.BrokenMapFinal): HashMap { + val proxy = HashMap() + return obj.toMap(proxy) + } + override fun fromProxy(proxy: HashMap): DifficultToSerialize.BrokenMapFinal { + return DifficultToSerialize.BrokenMapFinal() + .also { it.putAll(proxy) } + } + } + + @Suppress("unused") + class BrokenPublicKeySerializer : + CheckpointCustomSerializer { + override fun toProxy(obj: PublicKey): String { + throw FlowException("Broken on purpose") + } + + override fun fromProxy(proxy: String): PublicKey { + throw FlowException("Broken on purpose") + } + } + + @Suppress("unused") + class BrokenEdDSAPublicKeySerializer : + CheckpointCustomSerializer { + override fun toProxy(obj: EdDSAPublicKey): String { + throw FlowException("Broken on purpose") + } + + override fun fromProxy(proxy: String): EdDSAPublicKey { + throw FlowException("Broken on purpose") + } + } + +} diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 77d95abacd..0c2552c37b 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -644,8 +644,8 @@ open class Node(configuration: NodeConfiguration, storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), checkpointSerializer = KryoCheckpointSerializer, - checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) - ) + checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader).withCheckpointCustomSerializers(cordappLoader.cordapps.flatMap { it.checkpointCustomSerializers }) + ) } /** Starts a blocking event loop for message dispatch. */ diff --git a/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt b/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt index 97a5672846..bb2fce1a58 100644 --- a/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt +++ b/node/src/main/kotlin/net/corda/node/internal/cordapp/JarScanningCordappLoader.kt @@ -18,6 +18,7 @@ import net.corda.core.internal.notary.NotaryService import net.corda.core.internal.notary.SinglePartyNotaryService import net.corda.core.node.services.CordaService import net.corda.core.schemas.MappedSchema +import net.corda.core.serialization.CheckpointCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.SerializeAsToken @@ -185,6 +186,7 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths: findServices(this), findWhitelists(url), findSerializers(this), + findCheckpointSerializers(this), findCustomSchemas(this), findAllFlows(this), url.url, @@ -334,6 +336,10 @@ class JarScanningCordappLoader private constructor(private val cordappJarPaths: return scanResult.getClassesImplementingWithClassVersionCheck(SerializationCustomSerializer::class) } + private fun findCheckpointSerializers(scanResult: RestrictedScanResult): List> { + return scanResult.getClassesImplementingWithClassVersionCheck(CheckpointCustomSerializer::class) + } + private fun findCustomSchemas(scanResult: RestrictedScanResult): Set { return scanResult.getClassesWithSuperclass(MappedSchema::class).instances().toSet() } diff --git a/node/src/main/kotlin/net/corda/node/internal/cordapp/VirtualCordapps.kt b/node/src/main/kotlin/net/corda/node/internal/cordapp/VirtualCordapps.kt index 3f9e3b85f9..5ad5add351 100644 --- a/node/src/main/kotlin/net/corda/node/internal/cordapp/VirtualCordapps.kt +++ b/node/src/main/kotlin/net/corda/node/internal/cordapp/VirtualCordapps.kt @@ -32,6 +32,7 @@ internal object VirtualCordapp { services = listOf(), serializationWhitelists = listOf(), serializationCustomSerializers = listOf(), + checkpointCustomSerializers = listOf(), customSchemas = setOf(), info = Cordapp.Info.Default("corda-core", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), allFlows = listOf(), @@ -55,6 +56,7 @@ internal object VirtualCordapp { services = listOf(), serializationWhitelists = listOf(), serializationCustomSerializers = listOf(), + checkpointCustomSerializers = listOf(), customSchemas = setOf(NodeNotarySchemaV1), info = Cordapp.Info.Default("corda-notary", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), allFlows = listOf(), @@ -78,6 +80,7 @@ internal object VirtualCordapp { services = listOf(), serializationWhitelists = listOf(), serializationCustomSerializers = listOf(), + checkpointCustomSerializers = listOf(), customSchemas = setOf(RaftNotarySchemaV1), info = Cordapp.Info.Default("corda-notary-raft", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), allFlows = listOf(), @@ -101,6 +104,7 @@ internal object VirtualCordapp { services = listOf(), serializationWhitelists = listOf(), serializationCustomSerializers = listOf(), + checkpointCustomSerializers = listOf(), customSchemas = setOf(BFTSmartNotarySchemaV1), info = Cordapp.Info.Default("corda-notary-bft-smart", versionInfo.vendor, versionInfo.releaseVersion, "Open Source (Apache 2)"), allFlows = listOf(), diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt index b6c43ddc6d..f037e2dfbb 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt @@ -1,6 +1,7 @@ package net.corda.serialization.internal import net.corda.core.KeepForDJVM +import net.corda.core.serialization.CheckpointCustomSerializer import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.EncodingWhitelist import net.corda.core.serialization.SerializationEncoding @@ -13,7 +14,8 @@ data class CheckpointSerializationContextImpl @JvmOverloads constructor( override val properties: Map, override val objectReferencesEnabled: Boolean, override val encoding: SerializationEncoding?, - override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : CheckpointSerializationContext { + override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist, + override val checkpointCustomSerializers: Iterable> = emptyList()) : CheckpointSerializationContext { override fun withProperty(property: Any, value: Any): CheckpointSerializationContext { return copy(properties = properties + (property to value)) } @@ -34,4 +36,6 @@ data class CheckpointSerializationContextImpl @JvmOverloads constructor( override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding) override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist) + override fun withCheckpointCustomSerializers(checkpointCustomSerializers : Iterable>) + = copy(checkpointCustomSerializers = checkpointCustomSerializers) } \ No newline at end of file diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt index 61bf91aac9..116016b991 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt @@ -2,6 +2,7 @@ package net.corda.coretesting.internal import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme import net.corda.core.internal.createInstancesOfClassesImplementing +import net.corda.core.serialization.CheckpointCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.internal.SerializationEnvironment @@ -25,8 +26,11 @@ fun createTestSerializationEnv(): SerializationEnvironment { } fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment { + var customCheckpointSerializers: Set> = emptySet() val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) { val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) + customCheckpointSerializers = createInstancesOfClassesImplementing(classLoader, CheckpointCustomSerializer::class.java) + val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet() Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists), @@ -44,7 +48,7 @@ fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironm AMQP_RPC_SERVER_CONTEXT, AMQP_RPC_CLIENT_CONTEXT, AMQP_STORAGE_CONTEXT, - KRYO_CHECKPOINT_CONTEXT, + KRYO_CHECKPOINT_CONTEXT.withCheckpointCustomSerializers(customCheckpointSerializers), KryoCheckpointSerializer ) }