diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt new file mode 100644 index 0000000000..36d8353b0b --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomSerializerCheckpointAdaptor.kt @@ -0,0 +1,54 @@ +package net.corda.nodeapi.internal.serialization.kryo + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import net.corda.core.serialization.SerializationCustomSerializer +import net.corda.serialization.internal.amqp.CORDAPP_TYPE +import net.corda.serialization.internal.amqp.PROXY_TYPE +import java.lang.reflect.Type +import kotlin.reflect.jvm.javaType +import kotlin.reflect.jvm.jvmErasure + +class CustomSerializerCheckpointAdaptor(private val userSerializer : SerializationCustomSerializer) : Serializer() { + + val type: Type + val proxyType: Type + + init { + val types = userSerializer::class.supertypes.filter { it.jvmErasure == SerializationCustomSerializer::class } + .flatMap { it.arguments } + .map { it.type!!.javaType } + if (types.size != 2) { + throw CustomSerializerCheckpointAdaptorException("Unable to determine serializer parent types") + } + type = types[CORDAPP_TYPE] + proxyType = types[PROXY_TYPE] + } + + override fun write(kryo: Kryo, output: Output, obj: OBJ) { + try { + kryo.writeClassAndObject(output, userSerializer.toProxy(obj)) + } catch (e: Exception) { + throw CustomSerializerCheckpointAdaptorException("Failed converting ${type.typeName} to ${proxyType.typeName}", e) + } + } + + override fun read(kryo: Kryo, input: Input, type: Class): OBJ { + try { + @Suppress("UNCHECKED_CAST") + return userSerializer.fromProxy(kryo.readClassAndObject(input) as PROXY) + } catch (e: Exception) { + throw CustomSerializerCheckpointAdaptorException("Failed converting ${proxyType.typeName} to ${this.type.typeName}", e) + } + } +} + +class CustomSerializerCheckpointAdaptorException : java.lang.Exception { + constructor() : super() + constructor(message: String?) : super(message) + constructor(message: String?, cause: Throwable?) : super(message, cause) + constructor(cause: Throwable?) : super(cause) + constructor(message: String?, cause: Throwable?, enableSuppression: Boolean, writableStackTrace: Boolean) : super(message, cause, enableSuppression, writableStackTrace) +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt index 6a73119ce6..646a464f24 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt @@ -11,6 +11,7 @@ import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.serializers.ClosureSerializer import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationContext @@ -41,6 +42,7 @@ private object AutoCloseableSerialisationDetector : Serializer() object KryoCheckpointSerializer : CheckpointSerializer { private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() + private var cordappSerializers: List> = listOf() private fun getPool(context: CheckpointSerializationContext): KryoPool { return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { @@ -51,6 +53,19 @@ object KryoCheckpointSerializer : CheckpointSerializer { val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } serializer.kryo.apply { field.set(this, classResolver) + + for (customSerializer in cordappSerializers) { + + val typeName = customSerializer.type.typeName.substringBefore('<') + val clazz = context.deserializationClassLoader.loadClass(typeName) + + if (clazz.isInterface){ + addDefaultSerializer(clazz, customSerializer) + } else { + register(clazz, customSerializer) + } + } + // don't allow overriding the public key serializer for checkpointing DefaultKryoCustomizer.customize(this) addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) @@ -120,6 +135,10 @@ object KryoCheckpointSerializer : CheckpointSerializer { }) } } + + fun addCordappSerializers(customSerializers: Collection>) { + cordappSerializers = customSerializers.map { CustomSerializerCheckpointAdaptor(it) } + } } val KRYO_CHECKPOINT_CONTEXT = CheckpointSerializationContextImpl( diff --git a/node/build.gradle b/node/build.gradle index 58f7a5498e..f76249bba8 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -191,7 +191,8 @@ dependencies { // Integration test helpers integrationTestCompile "junit:junit:$junit_version" integrationTestCompile "org.assertj:assertj-core:${assertj_version}" - + integrationTestCompile "com.github.andrewoma.dexx:kollection:0.7" + // BFT-Smart dependencies compile 'com.github.bft-smart:library:master-v1.1-beta-g6215ec8-87' compile 'commons-codec:commons-codec:1.13' diff --git a/node/src/integration-test/kotlin/net/corda/node/MockNetworkCustomSerializerCheckpointTest.kt b/node/src/integration-test/kotlin/net/corda/node/MockNetworkCustomSerializerCheckpointTest.kt new file mode 100644 index 0000000000..e21909db12 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/MockNetworkCustomSerializerCheckpointTest.kt @@ -0,0 +1,68 @@ +package net.corda.node + +import co.paralleluniverse.fibers.Suspendable +import com.github.andrewoma.dexx.kollection.ImmutableMap +import com.github.andrewoma.dexx.kollection.immutableMapOf +import com.github.andrewoma.dexx.kollection.toImmutableMap +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.serialization.SerializationCustomSerializer +import net.corda.testing.node.MockNetwork +import net.corda.testing.node.MockNetworkParameters +import net.corda.testing.node.internal.enclosedCordapp +import org.assertj.core.api.Assertions +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.time.Duration + +class MockNetworkCustomSerializerCheckpointTest{ + private lateinit var mockNetwork: MockNetwork + + @Before + fun setup() { + mockNetwork = MockNetwork(MockNetworkParameters(cordappsForAllNodes = listOf(enclosedCordapp()))) + } + + @After + fun shutdown() { + mockNetwork.stopNodes() + } + + @Test(timeout = 300_000) + fun `flow suspend with custom kryo serializer`() { + val node = mockNetwork.createPartyNode() + val expected = 5 + val actual = node.startFlow(TestFlow(5)).get() + + Assertions.assertThat(actual).isEqualTo(expected) + } + + @StartableByRPC + class TestFlow(private val purchase: Int) : FlowLogic() { + @Suspendable + override fun call(): Int { + + // This object is difficult to serialize with Kryo + val difficultToSerialize: ImmutableMap = immutableMapOf("foo" to purchase) + + // Force a checkpoint + sleep(Duration.ofSeconds(10), maySkipCheckpoint = false) + + // Return value from deserialized object + return difficultToSerialize["foo"] ?: 0 + } + } + + @Suppress("unused") + class TestSerializer : SerializationCustomSerializer, HashMap> { + override fun toProxy(obj: ImmutableMap): HashMap { + val proxy = HashMap() + return obj.toMap(proxy) + } + + override fun fromProxy(proxy: HashMap): ImmutableMap { + return proxy.toImmutableMap() + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 77d95abacd..0d1c94f37f 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -643,7 +643,7 @@ open class Node(configuration: NodeConfiguration, rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null, //even Shell embeded in the node connects via RPC to the node storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), - checkpointSerializer = KryoCheckpointSerializer, + checkpointSerializer = KryoCheckpointSerializer.also { checkpointSerializer -> checkpointSerializer.addCordappSerializers(cordappLoader.cordapps.flatMap { it.serializationCustomSerializers }) }, checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) ) } diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt index 61bf91aac9..8910a44d83 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt @@ -25,8 +25,9 @@ fun createTestSerializationEnv(): SerializationEnvironment { } fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment { + var customSerializers: Set> = emptySet() val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) { - val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) + customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet() Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists), @@ -45,7 +46,7 @@ fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironm AMQP_RPC_CLIENT_CONTEXT, AMQP_STORAGE_CONTEXT, KRYO_CHECKPOINT_CONTEXT, - KryoCheckpointSerializer + KryoCheckpointSerializer.also { it.addCordappSerializers(customSerializers) } ) }