Merge branch 'release/os/4.4' of https://github.com/corda/corda into nnagy-os4.4-os4.5-20200426-2

 Conflicts:
	node-api/src/main/kotlin/net/corda/nodeapi/internal/persistence/HibernateConfiguration.kt
	node-api/src/main/kotlin/net/corda/nodeapi/internal/rpc/client/AMQPClientSerializationScheme.kt
This commit is contained in:
nikinagy 2020-04-24 16:57:57 +01:00
commit 8eab8653cd
3 changed files with 178 additions and 0 deletions

View File

@ -141,6 +141,10 @@ object DefaultKryoCustomizer {
register(ContractUpgradeWireTransaction::class.java, ContractUpgradeWireTransactionSerializer) register(ContractUpgradeWireTransaction::class.java, ContractUpgradeWireTransactionSerializer)
register(ContractUpgradeFilteredTransaction::class.java, ContractUpgradeFilteredTransactionSerializer) register(ContractUpgradeFilteredTransaction::class.java, ContractUpgradeFilteredTransactionSerializer)
addDefaultSerializer(Iterator::class.java) {kryo, type ->
IteratorSerializer(type, CompatibleFieldSerializer<Iterator<*>>(kryo, type).apply { setIgnoreSyntheticFields(false) })
}
for (whitelistProvider in serializationWhitelists) { for (whitelistProvider in serializationWhitelists) {
val types = whitelistProvider.whitelist val types = whitelistProvider.whitelist
require(types.toSet().size == types.size) { require(types.toSet().size == types.size) {

View File

@ -0,0 +1,52 @@
package net.corda.node.serialization.kryo
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import java.lang.reflect.Field
class IteratorSerializer(type: Class<*>, private val serializer: Serializer<Iterator<*>>) : Serializer<Iterator<*>>(false, false) {
private val iterableReferenceField = findField(type, "this\$0")?.apply { isAccessible = true }
private val expectedModCountField = findField(type, "expectedModCount")?.apply { isAccessible = true }
private val iterableReferenceFieldType = iterableReferenceField?.type
private val modCountField = when (iterableReferenceFieldType) {
null -> null
else -> findField(iterableReferenceFieldType, "modCount")?.apply { isAccessible = true }
}
override fun write(kryo: Kryo, output: Output, obj: Iterator<*>) {
serializer.write(kryo, output, obj)
}
override fun read(kryo: Kryo, input: Input, type: Class<Iterator<*>>): Iterator<*> {
val iterator = serializer.read(kryo, input, type)
return fixIterator(iterator)
}
private fun fixIterator(iterator: Iterator<*>) : Iterator<*> {
// Set expectedModCount of iterator
val iterableInstance = iterableReferenceField?.get(iterator) ?: return iterator
val modCountValue = modCountField?.getInt(iterableInstance) ?: return iterator
expectedModCountField?.setInt(iterator, modCountValue)
return iterator
}
/**
* Find field in clazz or any superclass
*/
private fun findField(clazz: Class<*>, fieldName: String): Field? {
return clazz.declaredFields.firstOrNull { x -> x.name == fieldName } ?: when {
clazz.superclass != null -> {
// Look in superclasses
findField(clazz.superclass, fieldName)
}
else -> null // Not found
}
}
}

View File

@ -0,0 +1,122 @@
package net.corda.node.serialization.kryo
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.serialization.internal.CordaSerializationEncoding
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import java.util.*
import kotlin.collections.ArrayList
import kotlin.collections.HashMap
import kotlin.collections.HashSet
import kotlin.collections.LinkedHashMap
import kotlin.collections.LinkedHashSet
@RunWith(Parameterized::class)
class ArrayListItrConcurrentModificationException(private val compression: CordaSerializationEncoding?) {
companion object {
@Parameters(name = "{0}")
@JvmStatic
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
@get:Rule
val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true)
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
context = CheckpointSerializationContextImpl(
deserializationClassLoader = javaClass.classLoader,
whitelist = AllWhitelist,
properties = emptyMap(),
objectReferencesEnabled = true,
encoding = compression,
encodingWhitelist = rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
})
}
@Test(timeout=300_000)
fun `ArrayList iterator can checkpoint without error`() {
runTestWithCollection(ArrayList())
}
@Test(timeout=300_000)
fun `HashSet iterator can checkpoint without error`() {
runTestWithCollection(HashSet())
}
@Test(timeout=300_000)
fun `LinkedHashSet iterator can checkpoint without error`() {
runTestWithCollection(LinkedHashSet())
}
@Test(timeout=300_000)
fun `HashMap iterator can checkpoint without error`() {
runTestWithCollection(HashMap())
}
@Test(timeout=300_000)
fun `LinkedHashMap iterator can checkpoint without error`() {
runTestWithCollection(LinkedHashMap())
}
@Test(timeout=300_000)
fun `LinkedList iterator can checkpoint without error`() {
runTestWithCollection(LinkedList())
}
private data class TestCheckpoint<C,I>(val list: C, val iterator: I)
private fun runTestWithCollection(collection: MutableCollection<Int>) {
for (i in 1..100) {
collection.add(i)
}
val iterator = collection.iterator()
iterator.next()
val checkpoint = TestCheckpoint(collection, iterator)
val serializedBytes = checkpoint.checkpointSerialize(context)
val deserializedCheckpoint = serializedBytes.checkpointDeserialize(context)
assertThat(deserializedCheckpoint.list).isEqualTo(collection)
assertThat(deserializedCheckpoint.iterator.next()).isEqualTo(2)
assertThat(deserializedCheckpoint.iterator.hasNext()).isTrue()
}
private fun runTestWithCollection(collection: MutableMap<Int, Int>) {
for (i in 1..100) {
collection[i] = i
}
val iterator = collection.iterator()
iterator.next()
val checkpoint = TestCheckpoint(collection, iterator)
val serializedBytes = checkpoint.checkpointSerialize(context)
val deserializedCheckpoint = serializedBytes.checkpointDeserialize(context)
assertThat(deserializedCheckpoint.list).isEqualTo(collection)
assertThat(deserializedCheckpoint.iterator.next().key).isEqualTo(2)
assertThat(deserializedCheckpoint.iterator.hasNext()).isTrue()
}
}