CORDA-2006: Simplify checkpoint serialization (#4042)

* CORDA-2006: Simplify checkpoint serialization

* Supply rule to KryoTest
This commit is contained in:
Dominic Fox
2018-10-08 13:39:28 +01:00
committed by GitHub
parent c88d3d8c1b
commit d9ea19855f
23 changed files with 186 additions and 311 deletions

View File

@ -4,7 +4,6 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage
@ -21,7 +20,7 @@ object CheckpointVerifier {
*/
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->

View File

@ -21,8 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
@ -38,7 +37,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
@ -470,17 +469,19 @@ open class Node(configuration: NodeConfiguration,
private fun initialiseSerialization() {
if (!initialiseSerialization) return
val classloader = cordappLoader.appClassLoader
nodeSerializationEnv = SerializationEnvironmentImpl(
nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null, //even Shell embeded in the node connects via RPC to the node
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null) //even Shell embeded in the node connects via RPC to the node
checkpointSerializer = KryoCheckpointSerializer,
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
)
}
/** Starts a blocking event loop for message dispatch. */

View File

@ -12,10 +12,9 @@ import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationScheme
import net.corda.core.serialization.internal.CheckpointSerializer
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.*
import java.security.PublicKey
import java.util.concurrent.ConcurrentHashMap
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
@ -31,7 +30,7 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
}
object KryoSerializationScheme : CheckpointSerializationScheme {
object KryoCheckpointSerializer : CheckpointSerializer {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
private fun getPool(context: CheckpointSerializationContext): KryoPool {

View File

@ -127,7 +127,7 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)

View File

@ -22,7 +22,7 @@ import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger
@ -201,7 +201,7 @@ class RaftTransactionCommitLog<E, EK>(
class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = CheckpointSerializationFactory.defaultFactory
private val checkpointSerializer = CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.checkpointSerialize(context = context)
@ -213,7 +213,7 @@ class RaftTransactionCommitLog<E, EK>(
val size = buffer.readInt()
val serialized = ByteArray(size)
buffer.read(serialized)
return factory.deserialize(ByteSequence.of(serialized), type, context)
return checkpointSerializer.deserialize(ByteSequence.of(serialized), type, context)
}
}
}

View File

@ -13,7 +13,6 @@ import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
@ -23,11 +22,13 @@ import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.serialization.internal.*
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock
import org.assertj.core.api.Assertions.*
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@ -48,12 +49,12 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
private lateinit var factory: CheckpointSerializationFactory
@get:Rule
val serializationRule = CheckpointSerializationEnvironmentRule()
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
@ -69,15 +70,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday)
val bits = mike.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday))
val bits = mike.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("mike", birthday))
}
@Test
fun `null values`() {
val bob = Person("bob", null)
val bits = bob.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null))
val bits = bob.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("bob", null))
}
@Test
@ -85,10 +86,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext)
val deserialisedList = originalList.checkpointSerialize(noReferencesContext).checkpointDeserialize(noReferencesContext)
originalList += obj
deserialisedList += obj
assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext))
assertThat(deserialisedList.checkpointSerialize(noReferencesContext)).isEqualTo(originalList.checkpointSerialize(noReferencesContext))
}
@Test
@ -105,14 +106,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant
this += instant
}
assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext))
assertThat(listWithSameInstances.checkpointSerialize(noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(noReferencesContext))
}
@Test
fun `cyclic object graph`() {
val cyclic = Cyclic(3)
val bits = cyclic.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic)
val bits = cyclic.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(context)).isEqualTo(cyclic)
}
@Test
@ -124,7 +125,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedKeyPair = keyPair.checkpointSerialize(context).checkpointDeserialize(context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -132,28 +133,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.checkpointSerialize(factory, context)
val deserialised = serialised.checkpointDeserialize(factory, context)
val serialised = TestSingleton.checkpointSerialize(context)
val deserialised = serialised.checkpointDeserialize(context)
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
}
@Test
fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
}
@Test
fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
}
@ -161,7 +162,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(context).checkpointDeserialize(context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -171,7 +172,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator()
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(context).checkpointDeserialize(context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext())
}
@ -182,8 +183,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.checkpointSerialize(factory, context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context)
val serializedMetaData = meta.checkpointSerialize(context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(context)
assertEquals(meta2, meta)
}
@ -191,7 +192,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize Logger`() {
val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext)
val logger2 = logger.checkpointSerialize(storageContext).checkpointDeserialize(storageContext)
assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2)
}
@ -203,7 +204,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish),
rubbish.size,
rubbish.inputStream()
).checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
).checkpointSerialize(context).checkpointDeserialize(context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -230,8 +231,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32
))
val serializedBytes = expected.checkpointSerialize(factory, context)
val actual = serializedBytes.checkpointDeserialize(factory, context)
val serializedBytes = expected.checkpointSerialize(context)
val actual = serializedBytes.checkpointDeserialize(context)
assertEquals(expected, actual)
}
@ -278,14 +279,13 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
}
}
Tmp()
val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
null)
pt.checkpointSerialize(factory, context)
pt.checkpointSerialize(context)
}
@Test
@ -293,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size)
@ -308,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar")
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size)
@ -322,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(randomHash, exception2.requested)
}
@ -330,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() {
compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.checkpointSerialize(factory, context)
val compressed = data.checkpointSerialize(context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.checkpointDeserialize(factory, context))
assertArrayEquals(data, compressed.checkpointDeserialize(context))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".checkpointSerialize(factory, context)
catchThrowable { compressed.checkpointDeserialize(factory, context) }.run {
val compressed = "whatever".checkpointSerialize(context)
catchThrowable { compressed.checkpointDeserialize(context) }.run {
assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -351,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
val uncompressedSize = obj.checkpointSerialize(context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize)