CORDA-1391: Separate out Checkpoint serialization (#3922)

* Separate out Checkpoint serialization

* Update kdocs

* Rename checkpoint serialization extension methods

* Fix bungled rename

* Limit API changes

* Simplify CheckpointSerializationFactory

* Add CheckpointSerializationScheme to API checker

* CheckpointSerializationScheme should not be implemented

* Move checkpoint serialisation to internal package

* Remove CheckpointSerializationScheme from api-current

* Quarantine internal classes

* Remove checkpoint context from public API

* Remove checkpoint context from public API

* Fix test failures

* Completely decouple SerializationTestHelpers and CheckpointSerializationTestHelpers

* Remove CHECKPOINT use case

* Remove stray reference to checkpoint use case

* Fix broken test
This commit is contained in:
Dominic Fox
2018-09-19 14:23:29 +01:00
committed by GitHub
parent d10892c09e
commit 98c92ef16f
37 changed files with 677 additions and 253 deletions

View File

@ -5,11 +5,12 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.SubFlow
import net.corda.node.services.statemachine.SubFlowVersion
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
object CheckpointVerifier {
@ -19,13 +20,13 @@ object CheckpointVerifier {
* @throws CheckpointIncompatibleException if any offending checkpoint is found.
*/
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->
val checkpoint = try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext)
} catch (e: Exception) {
throw CheckpointIncompatibleException.CannotBeDeserialisedException(e)
}

View File

@ -21,6 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
@ -37,7 +38,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
@ -449,8 +450,8 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme())
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),

View File

@ -8,8 +8,8 @@ import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util
import net.corda.core.internal.kotlinObjectInstance
import net.corda.core.internal.writer
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.utilities.contextLogger
import net.corda.serialization.internal.AttachmentsClassLoader
import net.corda.serialization.internal.MutableClassWhitelist
@ -25,7 +25,7 @@ import java.util.*
/**
* Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo
*/
class CordaClassResolver(serializationContext: SerializationContext) : DefaultClassResolver() {
class CordaClassResolver(serializationContext: CheckpointSerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
// These classes are assignment-compatible Java equivalents of Kotlin classes.

View File

@ -14,12 +14,11 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.*
import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.checkUseCase
import net.corda.serialization.internal.serializationContextKey
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@ -275,16 +274,9 @@ object SignedTransactionSerializer : Serializer<SignedTransaction>() {
}
}
sealed class UseCaseSerializer<T>(private val allowedUseCases: EnumSet<SerializationContext.UseCase>) : Serializer<T>() {
protected fun checkUseCase() {
net.corda.serialization.internal.checkUseCase(allowedUseCases)
}
}
@ThreadSafe
object PrivateKeySerializer : UseCaseSerializer<PrivateKey>(EnumSet.of(Storage, Checkpoint)) {
object PrivateKeySerializer : Serializer<PrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: PrivateKey) {
checkUseCase()
output.writeBytesWithLength(obj.encoded)
}

View File

@ -10,10 +10,9 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationScheme
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.*
import java.security.PublicKey
@ -32,46 +31,30 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
}
abstract class AbstractKryoSerializationScheme : SerializationScheme {
object KryoSerializationScheme : CheckpointSerializationScheme {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool
protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool
// this can be overridden in derived serialization schemes
protected open val publicKeySerializer: Serializer<PublicKey> = PublicKeySerializer
private fun getPool(context: SerializationContext): KryoPool {
private fun getPool(context: CheckpointSerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()
SerializationContext.UseCase.RPCClient ->
rpcClientKryoPool(context)
SerializationContext.UseCase.RPCServer ->
rpcServerKryoPool(context)
else ->
KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context)), publicKeySerializer).apply { classLoader = it.second }
}.build()
}
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()
}
}
private fun <T : Any> SerializationContext.kryo(task: Kryo.() -> T): T {
private fun <T : Any> CheckpointSerializationContext.kryo(task: Kryo.() -> T): T {
return getPool(this).run { kryo ->
kryo.context.ensureCapacity(properties.size)
properties.forEach { kryo.context.put(it.key, it.value) }
@ -83,7 +66,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
val dataBytes = kryoMagic.consume(byteSequence)
?: throw KryoException("Serialized bytes header does not match expected format.")
return context.kryo {
@ -111,7 +94,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
override fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return context.kryo {
SerializedBytes(kryoOutput {
kryoMagic.writeTo(this)
@ -131,13 +114,11 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(
kryoMagic,
val KRYO_CHECKPOINT_CONTEXT = CheckpointSerializationContextImpl(
SerializationDefaults.javaClass.classLoader,
QuasarWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
null,
AlwaysAcceptEncodingWhitelist
)

View File

@ -1,14 +0,0 @@
package net.corda.node.serialization.kryo
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext
import net.corda.serialization.internal.CordaSerializationMagic
class KryoServerSerializationScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target == SerializationContext.UseCase.Checkpoint
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}

View File

@ -4,9 +4,9 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.*
import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage
@ -27,7 +27,7 @@ class ActionExecutorImpl(
private val checkpointStorage: CheckpointStorage,
private val flowMessaging: FlowMessaging,
private val stateMachineManager: StateMachineManagerInternal,
private val checkpointSerializationContext: SerializationContext,
private val checkpointSerializationContext: CheckpointSerializationContext,
metrics: MetricRegistry
) : ActionExecutor {
@ -237,7 +237,7 @@ class ActionExecutorImpl(
}
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
return checkpoint.checkpointSerialize(context = checkpointSerializationContext)
}
private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) {

View File

@ -12,8 +12,8 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
@ -69,7 +69,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: SerializationContext,
val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch
)
@ -369,7 +369,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
Event.Suspend(
ioRequest = ioRequest,
maySkipCheckpoint = skipPersistingCheckpoint,
fiber = this.serialize(context = serializationContext.value)
fiber = this.checkpointSerialize(context = serializationContext.value)
)
} catch (throwable: Throwable) {
Event.Error(throwable)

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
@ -36,7 +40,7 @@ import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.injectOldProgressTracker
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable
@ -103,7 +107,7 @@ class SingleThreadedStateMachineManager(
private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: SerializationContext? = null
private var checkpointSerializationContext: CheckpointSerializationContext? = null
private var actionExecutor: ActionExecutor? = null
override val allStateMachines: List<FlowLogic<*>>
@ -122,8 +126,8 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
@ -531,7 +535,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
@ -613,7 +617,7 @@ class SingleThreadedStateMachineManager(
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception)
null
@ -658,7 +662,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
val fiber = when (flowState) {
is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val logic = flowState.frozenFlowLogic.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -677,7 +681,7 @@ class SingleThreadedStateMachineManager(
fiber
}
is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -742,7 +746,7 @@ class SingleThreadedStateMachineManager(
}
}
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
private fun makeActionExecutor(checkpointSerializationContext: CheckpointSerializationContext): ActionExecutor {
return ActionExecutorImpl(
serviceHub,
checkpointStorage,

View File

@ -2,9 +2,9 @@ package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.ActionExecutor
import net.corda.node.services.statemachine.Event
@ -68,7 +68,7 @@ class FiberDeserializationChecker {
private val jobQueue = LinkedBlockingQueue<Job>()
private var foundUnrestorableFibers: Boolean = false
fun start(checkpointSerializationContext: SerializationContext) {
fun start(checkpointSerializationContext: CheckpointSerializationContext) {
require(checkerThread == null)
checkerThread = thread(name = "FiberDeserializationChecker") {
while (true) {
@ -76,7 +76,7 @@ class FiberDeserializationChecker {
when (job) {
is Job.Check -> {
try {
job.serializedFiber.deserialize(context = checkpointSerializationContext)
job.serializedFiber.checkpointDeserialize(context = checkpointSerializationContext)
} catch (throwable: Throwable) {
log.error("Encountered unrestorable checkpoint!", throwable)
foundUnrestorableFibers = true

View File

@ -20,10 +20,10 @@ import net.corda.core.flows.StateConsumptionDetails
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
@ -200,11 +200,11 @@ class RaftTransactionCommitLog<E, EK>(
}
class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = SerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = SerializationFactory.defaultFactory
private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = CheckpointSerializationFactory.defaultFactory
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.serialize(context = context)
val serialized = obj.checkpointSerialize(context = context)
buffer.writeInt(serialized.size)
buffer.write(serialized.bytes)
}

View File

@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.primitives.Ints
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
@ -13,6 +12,10 @@ import net.corda.core.contracts.PrivacySalt
import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.sequence
@ -36,16 +39,6 @@ import java.util.*
import kotlin.collections.ArrayList
import kotlin.test.*
class TestScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target != SerializationContext.UseCase.RPCClient
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}
@RunWith(Parameterized::class)
class KryoTests(private val compression: CordaSerializationEncoding?) {
companion object {
@ -55,18 +48,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) }
context = SerializationContextImpl(kryoMagic,
factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Storage,
compression,
rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
@ -77,15 +69,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday)
val bits = mike.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday))
val bits = mike.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday))
}
@Test
fun `null values`() {
val bob = Person("bob", null)
val bits = bob.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null))
val bits = bob.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null))
}
@Test
@ -93,10 +85,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext)
val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext)
originalList += obj
deserialisedList += obj
assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext))
assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext))
}
@Test
@ -113,14 +105,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant
this += instant
}
assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext))
assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext))
}
@Test
fun `cyclic object graph`() {
val cyclic = Cyclic(3)
val bits = cyclic.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic)
val bits = cyclic.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic)
}
@Test
@ -132,7 +124,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context)
val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -140,28 +132,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.serialize(factory, context)
val deserialised = serialised.deserialize(factory, context)
val serialised = TestSingleton.checkpointSerialize(factory, context)
val deserialised = serialised.checkpointDeserialize(factory, context)
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
}
@Test
fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
}
@Test
fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
}
@ -169,7 +161,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context)
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -179,7 +171,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.serialize(factory, context).deserialize(factory, context).iterator()
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext())
}
@ -190,16 +182,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.serialize(factory, context).bytes
val meta2 = serializedMetaData.deserialize<SignableData>(factory, context)
val serializedMetaData = meta.checkpointSerialize(factory, context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context)
assertEquals(meta2, meta)
}
@Test
fun `serialize - deserialize Logger`() {
val storageContext: SerializationContext = context // TODO: make it storage context
val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext)
val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext)
assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2)
}
@ -211,7 +203,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish),
rubbish.size,
rubbish.inputStream()
).serialize(factory, context).deserialize(factory, context)
).checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -238,8 +230,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32
))
val serializedBytes = expected.serialize(factory, context)
val actual = serializedBytes.deserialize(factory, context)
val serializedBytes = expected.checkpointSerialize(factory, context)
val actual = serializedBytes.checkpointDeserialize(factory, context)
assertEquals(expected, actual)
}
@ -286,15 +278,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
}
}
Tmp()
val factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) }
val context = SerializationContextImpl(kryoMagic,
val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.P2P,
null)
pt.serialize(factory, context)
pt.checkpointSerialize(factory, context)
}
@Test
@ -302,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size)
@ -317,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar")
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size)
@ -331,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(randomHash, exception2.requested)
}
@ -339,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() {
compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.serialize(factory, context)
val compressed = data.checkpointSerialize(factory, context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.deserialize(factory, context))
assertArrayEquals(data, compressed.checkpointDeserialize(factory, context))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".serialize(factory, context)
catchThrowable { compressed.deserialize(factory, context) }.run {
val compressed = "whatever".checkpointSerialize(factory, context)
catchThrowable { compressed.checkpointDeserialize(factory, context) }.run {
assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -360,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size
val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize)

View File

@ -3,9 +3,9 @@ package net.corda.node.services.persistence
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.internal.CheckpointIncompatibleException
import net.corda.node.internal.CheckpointVerifier
import net.corda.node.internal.configureDatabase
@ -189,9 +189,9 @@ class DBCheckpointStorageTests {
val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {}
}
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version)).getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
}
}