From f9ccb88feaedcc749196a36c8277bee5e3e1f835 Mon Sep 17 00:00:00 2001
From: Christian Sailer <christian.sailer@r3.com>
Date: Wed, 18 Mar 2020 09:37:56 +0000
Subject: [PATCH] ENT-4494 Harmonize Kryo serialization (#6069)

* Harmonize Kryo serialization (Custom serializer for iterators/collections)

* Fix package name

* Revert checkpoint compression change.

* Clean imports
---
 .../kryo/CustomIteratorSerializers.kt         | 121 ++++++++++++++++
 .../kryo/DefaultKryoCustomizer.kt             |  24 +++-
 .../internal/serialization/kryo/Kryo.kt       |  16 ++-
 .../kryo/KryoCheckpointSerializer.kt          |  13 +-
 .../serialization/kryo/KryoCheckpointTest.kt  | 132 ++++++++++++++++++
 5 files changed, 293 insertions(+), 13 deletions(-)
 create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt
 create mode 100644 node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt

diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt
new file mode 100644
index 0000000000..ff50a1c137
--- /dev/null
+++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt
@@ -0,0 +1,121 @@
+package net.corda.nodeapi.internal.serialization.kryo
+
+import com.esotericsoftware.kryo.Kryo
+import com.esotericsoftware.kryo.Serializer
+import com.esotericsoftware.kryo.io.Input
+import com.esotericsoftware.kryo.io.Output
+import java.lang.reflect.Constructor
+import java.lang.reflect.Field
+import java.util.LinkedList
+
+/**
+ * The [LinkedHashMap] and [LinkedHashSet] have a problem with the default Quasar/Kryo serialisation
+ * in that serialising an iterator (and subsequent [LinkedHashMap.Entry]) over a sufficiently large
+ * data set can lead to a stack overflow (because the object map is traversed recursively).
+ *
+ * This custom serializer records only the backing map and the iterator's `current` entry;
+ * [read] re-creates the iterator over the deserialized map and advances it just past that entry.
+ */
+internal object LinkedHashMapIteratorSerializer : Serializer<Iterator<*>>() {
+    private val DUMMY_MAP = linkedMapOf(1L to 1)
+    private val outerMapField: Field = getIterator()::class.java.superclass.getDeclaredField("this$0").apply { isAccessible = true }
+    private val currentField: Field = getIterator()::class.java.superclass.getDeclaredField("current").apply { isAccessible = true }
+
+    private val KEY_ITERATOR_CLASS: Class<MutableIterator<Long>> = DUMMY_MAP.keys.iterator().javaClass
+    private val VALUE_ITERATOR_CLASS: Class<MutableIterator<Int>> = DUMMY_MAP.values.iterator().javaClass
+    private val MAP_ITERATOR_CLASS: Class<MutableIterator<MutableMap.MutableEntry<Long, Int>>> = DUMMY_MAP.iterator().javaClass
+
+    /** Returns an entry iterator over the dummy map, so that its (super)class can be inspected reflectively. */
+    fun getIterator(): Any = DUMMY_MAP.iterator()
+
+    /** Writes the map that [obj] iterates over, then the value of its `current` field (null if unset). */
+    override fun write(kryo: Kryo, output: Output, obj: Iterator<*>) {
+        val current: Map.Entry<*, *>? = currentField.get(obj) as Map.Entry<*, *>?
+        kryo.writeClassAndObject(output, outerMapField.get(obj))
+        kryo.writeClassAndObject(output, current)
+    }
+
+    /** Rebuilds an iterator of class [type]; throws [IllegalStateException] if [type] is not a known iterator class. */
+    override fun read(kryo: Kryo, input: Input, type: Class<Iterator<*>>): Iterator<*> {
+        val outerMap = kryo.readClassAndObject(input) as Map<*, *>
+        val current = kryo.readClassAndObject(input) as? Map.Entry<*, *>
+        val restored: Iterator<*> = when (type) {
+            KEY_ITERATOR_CLASS -> outerMap.keys.iterator()
+            VALUE_ITERATOR_CLASS -> outerMap.values.iterator()
+            MAP_ITERATOR_CLASS -> outerMap.iterator()
+            else -> throw IllegalStateException("Invalid type")
+        }
+        // A null `current` is taken to mean nothing was consumed yet, so the fresh iterator is returned as is.
+        return if (current == null) restored else restored.returnToIteratorLocation(outerMap.keys.iterator(), current.key)
+    }
+
+    // Advances the receiver in lockstep with [keys] (a key iterator over the same map) until the entry
+    // keyed by [currentKey] has been consumed. Matching on the key rather than on the yielded element
+    // keeps the position exact when the key is null or when several entries share an equal value.
+    private fun Iterator<*>.returnToIteratorLocation(keys: Iterator<*>, currentKey: Any?): Iterator<*> {
+        while (hasNext() && keys.hasNext()) {
+            next()
+            val key = keys.next()
+            @Suppress("SuspiciousEqualsCombination")
+            if (key === currentKey || key == currentKey) break
+        }
+        return this
+    }
+}
+
+/**
+ * Serialising a [LinkedHashMap.Entry] with the default Quasar/Kryo serialisation can overflow the
+ * stack once the data set is sufficiently large, because the object map is traversed recursively;
+ * both [LinkedHashMap] and [LinkedHashSet] are affected.
+ *
+ * This custom serializer records nothing but the key and the value of an entry. The rest of the
+ * list isn't required at this scope.
+ */
+object LinkedHashMapEntrySerializer : Serializer<Map.Entry<*, *>>() {
+    // A throwaway one-element map, kept only so the LinkedHashMap$Entry class can be obtained from it.
+    // Its key and value types are irrelevant.
+    private val DUMMY_MAP = linkedMapOf(1L to 1)
+    fun getEntry(): Any = DUMMY_MAP.entries.first()
+    private val constr: Constructor<*> = getEntry()::class.java.declaredConstructors.single().apply { isAccessible = true }
+
+    /**
+     * Kryo would otherwise serialise "this" entry and then "this.after" recursively, building up a very
+     * large stack, so only the key and the value are written out.
+     */
+    override fun write(kryo: Kryo, output: Output, obj: Map.Entry<*, *>) {
+        // Key first, then value: read() consumes them in the same order.
+        kryo.writeClassAndObject(output, obj.key)
+        kryo.writeClassAndObject(output, obj.value)
+    }
+
+    override fun read(kryo: Kryo, input: Input, type: Class<Map.Entry<*, *>>): Map.Entry<*, *> {
+        // Array elements are evaluated left to right, so the key is read from the stream before the value.
+        val ctorArgs = arrayOf<Any?>(0, kryo.readClassAndObject(input), kryo.readClassAndObject(input), null)
+        return constr.newInstance(*ctorArgs) as Map.Entry<*, *>
+    }
+}
+
+/**
+ * Serializes a [LinkedList] list iterator as its owning list plus next index, to avoid linked list issues.
+ */
+object LinkedListItrSerializer : Serializer<ListIterator<Any>>() {
+    // A throwaway one-element list, kept only so the ListItr class can be obtained from it.
+    // Its element type is irrelevant.
+    private val DUMMY_LIST = LinkedList<Long>(listOf(1))
+    fun getListItr(): Any  = DUMMY_LIST.listIterator()
+
+    private val outerListField: Field = getListItr()::class.java.getDeclaredField("this$0").apply { isAccessible = true }
+
+    override fun write(kryo: Kryo, output: Output, obj: ListIterator<Any>) {
+        val owner = outerListField.get(obj)
+        kryo.writeClassAndObject(output, owner)
+        output.writeInt(obj.nextIndex())
+    }
+
+    override fun read(kryo: Kryo, input: Input, type: Class<ListIterator<Any>>): ListIterator<Any> {
+        val owner = kryo.readClassAndObject(input) as LinkedList<*>
+        return owner.listIterator(input.readInt())
+    }
+}
+
+
diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt
index 615f9a74f5..cf1728c34f 100644
--- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt
+++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt
@@ -10,7 +10,11 @@ import com.esotericsoftware.kryo.serializers.FieldSerializer
 import de.javakaffee.kryoserializers.ArraysAsListSerializer
 import de.javakaffee.kryoserializers.BitSetSerializer
 import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer
-import de.javakaffee.kryoserializers.guava.*
+import de.javakaffee.kryoserializers.guava.ImmutableListSerializer
+import de.javakaffee.kryoserializers.guava.ImmutableMapSerializer
+import de.javakaffee.kryoserializers.guava.ImmutableMultimapSerializer
+import de.javakaffee.kryoserializers.guava.ImmutableSetSerializer
+import de.javakaffee.kryoserializers.guava.ImmutableSortedSetSerializer
 import net.corda.core.contracts.ContractAttachment
 import net.corda.core.contracts.ContractClassName
 import net.corda.core.contracts.PrivacySalt
@@ -24,7 +28,11 @@ import net.corda.core.serialization.MissingAttachmentsException
 import net.corda.core.serialization.SerializationWhitelist
 import net.corda.core.serialization.SerializeAsToken
 import net.corda.core.serialization.SerializedBytes
-import net.corda.core.transactions.*
+import net.corda.core.transactions.ContractUpgradeFilteredTransaction
+import net.corda.core.transactions.ContractUpgradeWireTransaction
+import net.corda.core.transactions.NotaryChangeWireTransaction
+import net.corda.core.transactions.SignedTransaction
+import net.corda.core.transactions.WireTransaction
 import net.corda.core.utilities.NonEmptySet
 import net.corda.core.utilities.toNonEmptySet
 import net.corda.serialization.internal.DefaultWhitelist
@@ -51,8 +59,9 @@ import java.security.PrivateKey
 import java.security.PublicKey
 import java.security.cert.CertPath
 import java.security.cert.X509Certificate
-import java.util.*
-import kotlin.collections.ArrayList
+import java.util.Arrays
+import java.util.BitSet
+import java.util.ServiceLoader
 
 object DefaultKryoCustomizer {
     private val serializationWhitelists: List<SerializationWhitelist> by lazy {
@@ -70,7 +79,8 @@ object DefaultKryoCustomizer {
             instantiatorStrategy = CustomInstantiatorStrategy()
 
             // Required for HashCheckingStream (de)serialization.
-            // Note that return type should be specifically set to InputStream, otherwise it may not work, i.e. val aStream : InputStream = HashCheckingStream(...).
+            // Note that return type should be specifically set to InputStream, otherwise it may not work,
+            // i.e. val aStream : InputStream = HashCheckingStream(...).
             addDefaultSerializer(InputStream::class.java, InputStreamSerializer)
             addDefaultSerializer(SerializeAsToken::class.java, SerializeAsTokenSerializer<SerializeAsToken>())
             addDefaultSerializer(Logger::class.java, LoggerSerializer)
@@ -79,8 +89,10 @@ object DefaultKryoCustomizer {
             // WARNING: reordering the registrations here will cause a change in the serialized form, since classes
             // with custom serializers get written as registration ids. This will break backwards-compatibility.
             // Please add any new registrations to the end.
-            // TODO: re-organise registrations into logical groups before v1.0
 
+            addDefaultSerializer(LinkedHashMapIteratorSerializer.getIterator()::class.java.superclass, LinkedHashMapIteratorSerializer)
+            register(LinkedHashMapEntrySerializer.getEntry()::class.java, LinkedHashMapEntrySerializer)
+            register(LinkedListItrSerializer.getListItr()::class.java, LinkedListItrSerializer)
             register(Arrays.asList("").javaClass, ArraysAsListSerializer())
             register(LazyMappedList::class.java, LazyMappedListSerializer)
             register(SignedTransaction::class.java, SignedTransactionSerializer)
diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt
index e5de8a1341..929fa63a8e 100644
--- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt
+++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/Kryo.kt
@@ -18,6 +18,8 @@ import net.corda.core.serialization.SerializeAsTokenContext
 import net.corda.core.serialization.SerializedBytes
 import net.corda.core.transactions.*
 import net.corda.core.utilities.OpaqueBytes
+import net.corda.serialization.internal.checkUseCase
+import net.corda.core.utilities.SgxSupport
 import net.corda.serialization.internal.serializationContextKey
 import org.slf4j.Logger
 import org.slf4j.LoggerFactory
@@ -67,13 +69,17 @@ object SerializedBytesSerializer : Serializer<SerializedBytes<Any>>() {
  * set via the constructor and the class is immutable.
  */
 class ImmutableClassSerializer<T : Any>(val klass: KClass<T>) : Serializer<T>() {
-    val props = klass.memberProperties.sortedBy { it.name }
-    val propsByName = props.associateBy { it.name }
-    val constructor = klass.primaryConstructor!!
+    val props by lazy { klass.memberProperties.sortedBy { it.name } }
+    val propsByName by lazy { props.associateBy { it.name } }
+    val constructor by lazy { klass.primaryConstructor!! }
 
     init {
-        props.forEach {
-            require(it !is KMutableProperty<*>) { "$it mutable property of class: ${klass} is unsupported" }
+        // Verify that this class is immutable (all properties are final).
+        // We disable this check inside SGX as the reflection blows up.
+        if (!SgxSupport.isInsideEnclave) {
+            props.forEach {
+                require(it !is KMutableProperty<*>) { "$it mutable property of class: ${klass} is unsupported" }
+            }
         }
     }
 
diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt
index 7dddc2a65f..6a73119ce6 100644
--- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt
+++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointSerializer.kt
@@ -10,11 +10,20 @@ import com.esotericsoftware.kryo.io.Output
 import com.esotericsoftware.kryo.pool.KryoPool
 import com.esotericsoftware.kryo.serializers.ClosureSerializer
 import net.corda.core.internal.uncheckedCast
-import net.corda.core.serialization.*
+import net.corda.core.serialization.ClassWhitelist
+import net.corda.core.serialization.SerializationDefaults
+import net.corda.core.serialization.SerializedBytes
 import net.corda.core.serialization.internal.CheckpointSerializationContext
 import net.corda.core.serialization.internal.CheckpointSerializer
 import net.corda.core.utilities.ByteSequence
-import net.corda.serialization.internal.*
+import net.corda.serialization.internal.AlwaysAcceptEncodingWhitelist
+import net.corda.serialization.internal.ByteBufferInputStream
+import net.corda.serialization.internal.CheckpointSerializationContextImpl
+import net.corda.serialization.internal.CordaSerializationEncoding
+import net.corda.serialization.internal.CordaSerializationMagic
+import net.corda.serialization.internal.QuasarWhitelist
+import net.corda.serialization.internal.SectionId
+import net.corda.serialization.internal.encodingNotPermittedFormat
 import java.util.concurrent.ConcurrentHashMap
 
 val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt
new file mode 100644
index 0000000000..70ab07f54a
--- /dev/null
+++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt
@@ -0,0 +1,132 @@
+package net.corda.nodeapi.internal.serialization.kryo
+
+import org.junit.Test
+import org.junit.jupiter.api.assertDoesNotThrow
+import java.util.LinkedList
+import kotlin.test.assertEquals
+
+/** Round-trips iterators over linked collections through [KryoCheckpointSerializer] while they are consumed. */
+class KryoCheckpointTest {
+    private val testSize = 1000L
+
+    /**
+     * Checkpoints an entry iterator after every step; this also exercises [LinkedHashMapEntrySerializer].
+     */
+    @Test(timeout=300_000)
+    fun `linked hash map can checkpoint without error`() {
+        var lastKey = ""
+        val dummyMap = linkedMapOf<String, Long>()
+        for (i in 0..testSize) {
+            dummyMap[i.toString()] = i
+        }
+        var it = dummyMap.iterator()
+        while (it.hasNext()) {
+            lastKey = it.next().key
+            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+        assertEquals(testSize.toString(), lastKey)
+    }
+
+    @Test(timeout=300_000)
+    fun `empty linked hash map can checkpoint without error`() {
+        val dummyMap = linkedMapOf<String, Long>()
+        val it = dummyMap.iterator()
+        val itKeys = dummyMap.keys.iterator()
+        val itValues = dummyMap.values.iterator()
+        val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+        val bytesKeys = KryoCheckpointSerializer.serialize(itKeys, KRYO_CHECKPOINT_CONTEXT)
+        val bytesValues = KryoCheckpointSerializer.serialize(itValues, KRYO_CHECKPOINT_CONTEXT)
+        assertDoesNotThrow {
+            KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+            KryoCheckpointSerializer.deserialize(bytesKeys, itKeys.javaClass, KRYO_CHECKPOINT_CONTEXT)
+            KryoCheckpointSerializer.deserialize(bytesValues, itValues.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+    }
+
+    @Test(timeout=300_000)
+    fun `linked hash map with null values can checkpoint without error`() {
+        val dummyMap = linkedMapOf<String?, Long?>().apply {
+            put(null, null)
+        }
+        val it = dummyMap.iterator()
+        val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+
+        val itKeys = dummyMap.keys.iterator()
+        val bytesKeys = KryoCheckpointSerializer.serialize(itKeys, KRYO_CHECKPOINT_CONTEXT)
+
+        val itValues = dummyMap.values.iterator()
+        val bytesValues = KryoCheckpointSerializer.serialize(itValues, KRYO_CHECKPOINT_CONTEXT)
+
+        assertDoesNotThrow {
+            KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+            KryoCheckpointSerializer.deserialize(bytesKeys, itKeys.javaClass, KRYO_CHECKPOINT_CONTEXT)
+            KryoCheckpointSerializer.deserialize(bytesValues, itValues.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+    }
+
+    @Test(timeout=300_000)
+    fun `linked hash map keys can checkpoint without error`() {
+        var lastKey = ""
+        val dummyMap = linkedMapOf<String, Long>()
+        for (i in 0..testSize) {
+            dummyMap[i.toString()] = i
+        }
+        var it = dummyMap.keys.iterator()
+        while (it.hasNext()) {
+            lastKey = it.next()
+            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+        assertEquals(testSize.toString(), lastKey)
+    }
+
+    @Test(timeout=300_000)
+    fun `linked hash map values can checkpoint without error`() {
+        var lastValue = 0L
+        val dummyMap = linkedMapOf<String, Long>()
+        for (i in 0..testSize) {
+            dummyMap[i.toString()] = i
+        }
+        var it = dummyMap.values.iterator()
+        while (it.hasNext()) {
+            lastValue = it.next()
+            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+        assertEquals(testSize, lastValue)
+    }
+
+    /**
+     * Checkpoints a linked hash set iterator after every step and checks the walk ends on the last element.
+     */
+    @Test(timeout=300_000)
+    fun `linked hash set can checkpoint without error`() {
+        var result: Any = 0L
+        val dummySet = linkedSetOf<Any>().apply { addAll(0..testSize) }
+        var it = dummySet.iterator()
+        while (it.hasNext()) {
+            result = it.next()
+            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+        assertEquals(testSize, result)
+    }
+
+    /**
+     * Checkpoints a linked list iterator after every step; this exercises [LinkedListItrSerializer].
+     */
+    @Test(timeout=300_000)
+    fun `linked list can checkpoint without error`() {
+        var result: Any = 0L
+        val dummyList = LinkedList<Long>().apply { addAll(0..testSize) }
+
+        var it = dummyList.iterator()
+        while (it.hasNext()) {
+            result = it.next()
+            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
+            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
+        }
+        assertEquals(testSize, result)
+    }
+}