diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt index cf1728c34f..7d9d9b62d5 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/DefaultKryoCustomizer.kt @@ -141,6 +141,10 @@ object DefaultKryoCustomizer { register(ContractUpgradeWireTransaction::class.java, ContractUpgradeWireTransactionSerializer) register(ContractUpgradeFilteredTransaction::class.java, ContractUpgradeFilteredTransactionSerializer) + addDefaultSerializer(Iterator::class.java) {kryo, type -> + IteratorSerializer(type, CompatibleFieldSerializer>(kryo, type).apply { setIgnoreSyntheticFields(false) }) + } + for (whitelistProvider in serializationWhitelists) { val types = whitelistProvider.whitelist require(types.toSet().size == types.size) { diff --git a/node/src/main/kotlin/net/corda/node/serialization/kryo/IteratorSerializer.kt b/node/src/main/kotlin/net/corda/node/serialization/kryo/IteratorSerializer.kt new file mode 100644 index 0000000000..382ae840c5 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/serialization/kryo/IteratorSerializer.kt @@ -0,0 +1,52 @@ +package net.corda.node.serialization.kryo + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import java.lang.reflect.Field + +class IteratorSerializer(type: Class<*>, private val serializer: Serializer>) : Serializer>(false, false) { + + private val iterableReferenceField = findField(type, "this\$0")?.apply { isAccessible = true } + private val expectedModCountField = findField(type, "expectedModCount")?.apply { isAccessible = true } + private val iterableReferenceFieldType = iterableReferenceField?.type + private val modCountField = when (iterableReferenceFieldType) { + null -> null + else -> findField(iterableReferenceFieldType, "modCount")?.apply { isAccessible = true } + } + + override fun write(kryo: Kryo, output: Output, obj: Iterator<*>) { + serializer.write(kryo, output, obj) + } + + override fun read(kryo: Kryo, input: Input, type: Class>): Iterator<*> { + val iterator = serializer.read(kryo, input, type) + return fixIterator(iterator) + } + + private fun fixIterator(iterator: Iterator<*>) : Iterator<*> { + + // Set expectedModCount of iterator + val iterableInstance = iterableReferenceField?.get(iterator) ?: return iterator + val modCountValue = modCountField?.getInt(iterableInstance) ?: return iterator + expectedModCountField?.setInt(iterator, modCountValue) + + return iterator + } + + /** + * Find field in clazz or any superclass + */ + private fun findField(clazz: Class<*>, fieldName: String): Field? { + return clazz.declaredFields.firstOrNull { x -> x.name == fieldName } ?: when { + clazz.superclass != null -> { + // Look in superclasses + findField(clazz.superclass, fieldName) + } + else -> null // Not found + } + } +} + + diff --git a/node/src/test/kotlin/net/corda/node/serialization/kryo/ArrayListItrConcurrentModificationException.kt b/node/src/test/kotlin/net/corda/node/serialization/kryo/ArrayListItrConcurrentModificationException.kt new file mode 100644 index 0000000000..44a48c793d --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/serialization/kryo/ArrayListItrConcurrentModificationException.kt @@ -0,0 +1,122 @@ +package net.corda.node.serialization.kryo + +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.serialization.EncodingWhitelist +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.CheckpointSerializationContextImpl +import net.corda.serialization.internal.CordaSerializationEncoding +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule +import net.corda.testing.internal.rigorousMock +import org.assertj.core.api.Assertions.assertThat +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.junit.runners.Parameterized.Parameters +import java.util.* +import kotlin.collections.ArrayList +import kotlin.collections.HashMap +import kotlin.collections.HashSet +import kotlin.collections.LinkedHashMap +import kotlin.collections.LinkedHashSet + +@RunWith(Parameterized::class) +class ArrayListItrConcurrentModificationException(private val compression: CordaSerializationEncoding?) { + companion object { + @Parameters(name = "{0}") + @JvmStatic + fun compression() = arrayOf(null) + CordaSerializationEncoding.values() + } + + @get:Rule + val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true) + private lateinit var context: CheckpointSerializationContext + + @Before + fun setup() { + context = CheckpointSerializationContextImpl( + deserializationClassLoader = javaClass.classLoader, + whitelist = AllWhitelist, + properties = emptyMap(), + objectReferencesEnabled = true, + encoding = compression, + encodingWhitelist = rigorousMock().also { + if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression) + }) + } + + @Test(timeout=300_000) + fun `ArrayList iterator can checkpoint without error`() { + runTestWithCollection(ArrayList()) + } + + @Test(timeout=300_000) + fun `HashSet iterator can checkpoint without error`() { + runTestWithCollection(HashSet()) + } + + @Test(timeout=300_000) + fun `LinkedHashSet iterator can checkpoint without error`() { + runTestWithCollection(LinkedHashSet()) + } + + @Test(timeout=300_000) + fun `HashMap iterator can checkpoint without error`() { + runTestWithCollection(HashMap()) + } + + @Test(timeout=300_000) + fun `LinkedHashMap iterator can checkpoint without error`() { + runTestWithCollection(LinkedHashMap()) + } + + @Test(timeout=300_000) + fun `LinkedList iterator can checkpoint without error`() { + runTestWithCollection(LinkedList()) + } + + private data class TestCheckpoint(val list: C, val iterator: I) + + private fun runTestWithCollection(collection: MutableCollection) { + + for (i in 1..100) { + collection.add(i) + } + + val iterator = collection.iterator() + iterator.next() + + val checkpoint = TestCheckpoint(collection, iterator) + + val serializedBytes = checkpoint.checkpointSerialize(context) + val deserializedCheckpoint = serializedBytes.checkpointDeserialize(context) + + assertThat(deserializedCheckpoint.list).isEqualTo(collection) + assertThat(deserializedCheckpoint.iterator.next()).isEqualTo(2) + assertThat(deserializedCheckpoint.iterator.hasNext()).isTrue() + } + + private fun runTestWithCollection(collection: MutableMap) { + + for (i in 1..100) { + collection[i] = i + } + + val iterator = collection.iterator() + iterator.next() + + val checkpoint = TestCheckpoint(collection, iterator) + + val serializedBytes = checkpoint.checkpointSerialize(context) + val deserializedCheckpoint = serializedBytes.checkpointDeserialize(context) + + assertThat(deserializedCheckpoint.list).isEqualTo(collection) + assertThat(deserializedCheckpoint.iterator.next().key).isEqualTo(2) + assertThat(deserializedCheckpoint.iterator.hasNext()).isTrue() + } +}