From 28f00ce92a17483ab07dc848af398ece90d30b0f Mon Sep 17 00:00:00 2001 From: Rick Parker Date: Wed, 22 Apr 2020 16:21:39 +0100 Subject: [PATCH] CORDA-3701 Fix bugs in some iterator checkpoint serializers (#6135) * CORDA-3701 Fix bugs in some iterator checkpoint serializers * Added some more tests and tidied up implementation some more. * Fix imports to be detekt compliant * Add timeouts to tests --- .../kryo/CustomIteratorSerializers.kt | 31 +++++++++---- .../serialization/kryo/KryoCheckpointTest.kt | 43 ++++++++++++++++++- 2 files changed, 64 insertions(+), 10 deletions(-) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt index ff50a1c137..b02779fae8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/kryo/CustomIteratorSerializers.kt @@ -6,6 +6,8 @@ import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import java.lang.reflect.Constructor import java.lang.reflect.Field +import java.util.LinkedHashMap +import java.util.LinkedHashSet import java.util.LinkedList /** @@ -37,30 +39,43 @@ internal object LinkedHashMapIteratorSerializer : Serializer>() { return when (type) { KEY_ITERATOR_CLASS -> { val current = (kryo.readClassAndObject(input) as? Map.Entry<*, *>)?.key - outerMap.keys.iterator().returnToIteratorLocation(current) + outerMap.keys.iterator().returnToIteratorLocation(kryo, current) } VALUE_ITERATOR_CLASS -> { val current = (kryo.readClassAndObject(input) as? Map.Entry<*, *>)?.value - outerMap.values.iterator().returnToIteratorLocation(current) + outerMap.values.iterator().returnToIteratorLocation(kryo, current) } MAP_ITERATOR_CLASS -> { val current = (kryo.readClassAndObject(input) as? Map.Entry<*, *>) - outerMap.iterator().returnToIteratorLocation(current) + outerMap.iterator().returnToIteratorLocation(kryo, current) } else -> throw IllegalStateException("Invalid type") } } - private fun Iterator<*>.returnToIteratorLocation(current: Any?) : Iterator<*> { + private fun Iterator<*>.returnToIteratorLocation(kryo: Kryo, current: Any?): Iterator<*> { while (this.hasNext()) { val key = this.next() - @Suppress("SuspiciousEqualsCombination") - if (current == null || key === current || key == current) { - break - } + if (iteratedObjectsEqual(kryo, key, current)) break } return this } + + private fun iteratedObjectsEqual(kryo: Kryo, a: Any?, b: Any?): Boolean = if (a == null || b == null) { + a == b + } else { + a === b || mapEntriesEqual(kryo, a, b) || kryoOptimisesAwayReferencesButEqual(kryo, a, b) + } + + /** + * Kryo can substitute brand new created instances for some types during deserialization, making the identity check fail. + * Fall back to equality for those. + */ + private fun kryoOptimisesAwayReferencesButEqual(kryo: Kryo, a: Any, b: Any) = + (!kryo.referenceResolver.useReferences(a.javaClass) && !kryo.referenceResolver.useReferences(b.javaClass) && a == b) + + private fun mapEntriesEqual(kryo: Kryo, a: Any, b: Any) = + (a is Map.Entry<*, *> && b is Map.Entry<*, *> && iteratedObjectsEqual(kryo, a.key, b.key)) } /** diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt index 70ab07f54a..e0906294b4 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/kryo/KryoCheckpointTest.kt @@ -1,5 +1,6 @@ package net.corda.nodeapi.internal.serialization.kryo +import org.junit.Ignore import org.junit.Test import org.junit.jupiter.api.assertDoesNotThrow import java.util.LinkedList @@ -47,12 +48,16 @@ class KryoCheckpointTest { @Test(timeout=300_000) fun `linked hash map with null values can checkpoint without error`() { val dummyMap = linkedMapOf().apply { - put(null, null) + put("foo", 2L) + put(null, null) + put("bar", 3L) } val it = dummyMap.iterator() val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT) val itKeys = dummyMap.keys.iterator() + itKeys.next() + itKeys.next() val bytesKeys = KryoCheckpointSerializer.serialize(itKeys, KRYO_CHECKPOINT_CONTEXT) val itValues = dummyMap.values.iterator() @@ -60,7 +65,8 @@ class KryoCheckpointTest { assertDoesNotThrow { KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT) - KryoCheckpointSerializer.deserialize(bytesKeys, itKeys.javaClass, KRYO_CHECKPOINT_CONTEXT) + val desItKeys = KryoCheckpointSerializer.deserialize(bytesKeys, itKeys.javaClass, KRYO_CHECKPOINT_CONTEXT) + assertEquals("bar", desItKeys.next()) KryoCheckpointSerializer.deserialize(bytesValues, itValues.javaClass, KRYO_CHECKPOINT_CONTEXT) } } @@ -97,6 +103,39 @@ class KryoCheckpointTest { assertEquals(testSize, lastValue) } + @Test(timeout = 300_000) + fun `linked hash map values can checkpoint without error, even with repeats`() { + var lastValue = "0" + val dummyMap = linkedMapOf() + for (i in 0..testSize) { + dummyMap[i.toString()] = (i % 10).toString() + } + var it = dummyMap.values.iterator() + while (it.hasNext()) { + lastValue = it.next() + val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT) + it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT) + } + assertEquals((testSize % 10).toString(), lastValue) + } + + @Ignore("Kryo optimizes boxed primitives so this does not work. Need to customise ReferenceResolver to stop it doing it.") + @Test(timeout = 300_000) + fun `linked hash map values can checkpoint without error, even with repeats for boxed primitives`() { + var lastValue = 0L + val dummyMap = linkedMapOf() + for (i in 0..testSize) { + dummyMap[i.toString()] = (i % 10) + } + var it = dummyMap.values.iterator() + while (it.hasNext()) { + lastValue = it.next() + val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT) + it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT) + } + assertEquals(testSize % 10, lastValue) + } + /** * This test just ensures that the checkpoints still work in light of [LinkedHashMapEntrySerializer]. */