CORDA-2006: Simplify checkpoint serialization (#4042)

* CORDA-2006: Simplify checkpoint serialization

* Supply rule to KryoTest
This commit is contained in:
Dominic Fox
2018-10-08 13:39:28 +01:00
committed by GitHub
parent c88d3d8c1b
commit d9ea19855f
23 changed files with 186 additions and 311 deletions

View File

@ -6,7 +6,6 @@ import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.* import net.corda.core.serialization.SerializationContext.*
import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
@ -35,7 +34,7 @@ class AMQPClientSerializationScheme(
} }
fun createSerializationEnv(classLoader: ClassLoader? = null): SerializationEnvironment { fun createSerializationEnv(classLoader: ClassLoader? = null): SerializationEnvironment {
return SerializationEnvironmentImpl( return SerializationEnvironment.with(
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList())) registerScheme(AMQPClientSerializationScheme(emptyList()))
}, },

View File

@ -1,74 +0,0 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/**
* A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization
* context.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private var _currentContext: CheckpointSerializationContext? = null
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext
if (context != null) _currentContext = context
try {
return block()
} finally {
if (context != null) _currentContext = priorContext
}
}
companion object {
/**
* A default factory for serialization/deserialization.
*/
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}

View File

@ -4,7 +4,7 @@ import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.P2P import net.corda.core.serialization.SerializationContext.UseCase.P2P
import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
@ -58,13 +58,11 @@ class LocalSerializationRule(private val label: String) : TestRule {
_contextSerializationEnv.set(null) _contextSerializationEnv.set(null)
} }
private fun createTestSerializationEnv(): SerializationEnvironmentImpl { private fun createTestSerializationEnv(): SerializationEnvironment {
val factory = SerializationFactoryImpl(mutableMapOf()).apply { val factory = SerializationFactoryImpl(mutableMapOf()).apply {
registerScheme(AMQPSerializationScheme(emptySet(), AccessOrderLinkedHashMap(128))) registerScheme(AMQPSerializationScheme(emptySet(), AccessOrderLinkedHashMap(128)))
} }
return object : SerializationEnvironmentImpl(factory, AMQP_P2P_CONTEXT) { return SerializationEnvironment.with(factory, AMQP_P2P_CONTEXT)
override fun toString() = "testSerializationEnv($label)"
}
} }
private class AMQPSerializationScheme( private class AMQPSerializationScheme(

View File

@ -13,75 +13,12 @@ import java.io.NotSerializableException
object CheckpointSerializationDefaults { object CheckpointSerializationDefaults {
@DeleteForDJVM @DeleteForDJVM
val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory val CHECKPOINT_SERIALIZER get() = effectiveSerializationEnv.checkpointSerializer
}
/**
* A class for serializing and deserializing objects at checkpoints, using Kryo serialization.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private val _currentContext = ThreadLocal<CheckpointSerializationContext?>()
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}
companion object {
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
} }
@KeepForDJVM @KeepForDJVM
@DoNotImplement @DoNotImplement
interface CheckpointSerializationScheme { interface CheckpointSerializer {
@Throws(NotSerializableException::class) @Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T
@ -167,32 +104,36 @@ interface CheckpointSerializationContext {
/* /*
* Convenience extension method for deserializing a ByteSequence, utilising the default factory. * Convenience extension method for deserializing a ByteSequence, utilising the default factory.
*/ */
inline fun <reified T : Any> ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, @JvmOverloads
context: CheckpointSerializationContext): T { inline fun <reified T : Any> ByteSequence.checkpointDeserialize(
return serializationFactory.deserialize(this, T::class.java, context) context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T {
return effectiveSerializationEnv.checkpointSerializer.deserialize(this, T::class.java, context)
} }
/** /**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory. * Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory.
*/ */
inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, @JvmOverloads
context: CheckpointSerializationContext): T { inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(
return serializationFactory.deserialize(this, T::class.java, context) context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T {
return effectiveSerializationEnv.checkpointSerializer.deserialize(this, T::class.java, context)
} }
/** /**
* Convenience extension method for deserializing a ByteArray, utilising the default factory. * Convenience extension method for deserializing a ByteArray, utilising the default factory.
*/ */
inline fun <reified T : Any> ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, @JvmOverloads
context: CheckpointSerializationContext): T { inline fun <reified T : Any> ByteArray.checkpointDeserialize(
context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T {
require(isNotEmpty()) { "Empty bytes" } require(isNotEmpty()) { "Empty bytes" }
return this.sequence().checkpointDeserialize(serializationFactory, context) return this.sequence().checkpointDeserialize(context)
} }
/** /**
* Convenience extension method for serializing an object of type T, utilising the default factory. * Convenience extension method for serializing an object of type T, utilising the default factory.
*/ */
fun <T : Any> T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, @JvmOverloads
context: CheckpointSerializationContext): SerializedBytes<T> { fun <T : Any> T.checkpointSerialize(
return serializationFactory.serialize(this, context) context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): SerializedBytes<T> {
return effectiveSerializationEnv.checkpointSerializer.serialize(this, context)
} }

View File

@ -11,38 +11,63 @@ import net.corda.core.serialization.SerializationFactory
@KeepForDJVM @KeepForDJVM
interface SerializationEnvironment { interface SerializationEnvironment {
companion object {
fun with(
serializationFactory: SerializationFactory,
p2pContext: SerializationContext,
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: CheckpointSerializationContext? = null,
checkpointSerializer: CheckpointSerializer? = null
): SerializationEnvironment =
SerializationEnvironmentImpl(
serializationFactory = serializationFactory,
p2pContext = p2pContext,
optionalRpcServerContext = rpcServerContext,
optionalRpcClientContext = rpcClientContext,
optionalStorageContext = storageContext,
optionalCheckpointContext = checkpointContext,
optionalCheckpointSerializer = checkpointSerializer
)
}
val serializationFactory: SerializationFactory val serializationFactory: SerializationFactory
val checkpointSerializationFactory: CheckpointSerializationFactory
val p2pContext: SerializationContext val p2pContext: SerializationContext
val rpcServerContext: SerializationContext val rpcServerContext: SerializationContext
val rpcClientContext: SerializationContext val rpcClientContext: SerializationContext
val storageContext: SerializationContext val storageContext: SerializationContext
val checkpointSerializer: CheckpointSerializer
val checkpointContext: CheckpointSerializationContext val checkpointContext: CheckpointSerializationContext
} }
@KeepForDJVM @KeepForDJVM
open class SerializationEnvironmentImpl( private class SerializationEnvironmentImpl(
override val serializationFactory: SerializationFactory, override val serializationFactory: SerializationFactory,
override val p2pContext: SerializationContext, override val p2pContext: SerializationContext,
rpcServerContext: SerializationContext? = null, private val optionalRpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null, private val optionalRpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null, private val optionalStorageContext: SerializationContext? = null,
checkpointContext: CheckpointSerializationContext? = null, private val optionalCheckpointContext: CheckpointSerializationContext? = null,
checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment { private val optionalCheckpointSerializer: CheckpointSerializer? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: CheckpointSerializationContext
override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory
init { override val rpcServerContext: SerializationContext get() = optionalRpcServerContext ?:
rpcServerContext?.let { this.rpcServerContext = it } throw UnsupportedOperationException("RPC server serialization not supported in this environment")
rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it } override val rpcClientContext: SerializationContext get() = optionalRpcClientContext ?:
checkpointContext?.let { this.checkpointContext = it } throw UnsupportedOperationException("RPC client serialization not supported in this environment")
checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it }
} override val storageContext: SerializationContext get() = optionalStorageContext ?:
throw UnsupportedOperationException("Storage serialization not supported in this environment")
override val checkpointContext: CheckpointSerializationContext get() = optionalCheckpointContext ?:
throw UnsupportedOperationException("Checkpoint serialization not supported in this environment")
override val checkpointSerializer: CheckpointSerializer get() = optionalCheckpointSerializer ?:
throw UnsupportedOperationException("Checkpoint serialization not supported in this environment")
} }
private val _nodeSerializationEnv = SimpleToggleField<SerializationEnvironment>("nodeSerializationEnv", true) private val _nodeSerializationEnv = SimpleToggleField<SerializationEnvironment>("nodeSerializationEnv", true)

View File

@ -1,7 +1,5 @@
package net.corda.core.flows; package net.corda.core.flows;
import net.corda.core.serialization.internal.CheckpointSerializationDefaults;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.core.serialization.SerializationDefaults; import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory; import net.corda.core.serialization.SerializationFactory;
import net.corda.testing.core.SerializationEnvironmentRule; import net.corda.testing.core.SerializationEnvironmentRule;
@ -32,12 +30,10 @@ public class SerializationApiInJavaTest {
SerializationDefaults defaults = SerializationDefaults.INSTANCE; SerializationDefaults defaults = SerializationDefaults.INSTANCE;
SerializationFactory factory = defaults.getSERIALIZATION_FACTORY(); SerializationFactory factory = defaults.getSERIALIZATION_FACTORY();
CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE;
CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY();
serialize("hello", factory, defaults.getP2P_CONTEXT()); serialize("hello", factory, defaults.getP2P_CONTEXT());
serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT()); serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT());
serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT()); serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT());
serialize("hello", factory, defaults.getSTORAGE_CONTEXT()); serialize("hello", factory, defaults.getSTORAGE_CONTEXT());
checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT()); checkpointSerialize("hello");
} }
} }

View File

@ -14,7 +14,7 @@ import net.corda.core.node.services.AttachmentId
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
@ -393,7 +393,7 @@ internal constructor(private val initSerEnv: Boolean,
// We need to to set serialization env, because generation of parameters is run from Cordform. // We need to to set serialization env, because generation of parameters is run from Cordform.
private fun initialiseSerialization() { private fun initialiseSerialization() {
_contextSerializationEnv.set(SerializationEnvironmentImpl( _contextSerializationEnv.set(SerializationEnvironment.with(
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPParametersSerializationScheme) registerScheme(AMQPParametersSerializationScheme)
}, },

View File

@ -4,7 +4,6 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
@ -21,7 +20,7 @@ object CheckpointVerifier {
*/ */
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) { fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
) )
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) -> checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->

View File

@ -21,8 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
@ -38,7 +37,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoSerializationScheme import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.node.services.Permissions import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
@ -470,17 +469,19 @@ open class Node(configuration: NodeConfiguration,
private fun initialiseSerialization() { private fun initialiseSerialization() {
if (!initialiseSerialization) return if (!initialiseSerialization) return
val classloader = cordappLoader.appClassLoader val classloader = cordappLoader.appClassLoader
nodeSerializationEnv = SerializationEnvironmentImpl( nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
}, },
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null, //even Shell embeded in the node connects via RPC to the node
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null) //even Shell embeded in the node connects via RPC to the node checkpointSerializer = KryoCheckpointSerializer,
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
)
} }
/** Starts a blocking event loop for message dispatch. */ /** Starts a blocking event loop for message dispatch. */

View File

@ -12,10 +12,9 @@ import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationScheme import net.corda.core.serialization.internal.CheckpointSerializer
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import java.security.PublicKey
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0)) val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
@ -31,7 +30,7 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!") override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
} }
object KryoSerializationScheme : CheckpointSerializationScheme { object KryoCheckpointSerializer : CheckpointSerializer {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>() private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
private fun getPool(context: CheckpointSerializationContext): KryoPool { private fun getPool(context: CheckpointSerializationContext): KryoPool {

View File

@ -127,7 +127,7 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) { override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence() checkQuasarJavaAgentPresence()
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
) )
this.checkpointSerializationContext = checkpointSerializationContext this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext) this.actionExecutor = makeActionExecutor(checkpointSerializationContext)

View File

@ -22,7 +22,7 @@ import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
@ -201,7 +201,7 @@ class RaftTransactionCommitLog<E, EK>(
class CordaKryoSerializer<T : Any> : TypeSerializer<T> { class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY) private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = CheckpointSerializationFactory.defaultFactory private val checkpointSerializer = CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) { override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.checkpointSerialize(context = context) val serialized = obj.checkpointSerialize(context = context)
@ -213,7 +213,7 @@ class RaftTransactionCommitLog<E, EK>(
val size = buffer.readInt() val size = buffer.readInt()
val serialized = ByteArray(size) val serialized = ByteArray(size)
buffer.read(serialized) buffer.read(serialized)
return factory.deserialize(ByteSequence.of(serialized), type, context) return checkpointSerializer.deserialize(ByteSequence.of(serialized), type, context)
} }
} }
} }

View File

@ -13,7 +13,6 @@ import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
@ -23,11 +22,13 @@ import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import org.assertj.core.api.Assertions.* import org.assertj.core.api.Assertions.*
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Before import org.junit.Before
import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.Parameterized import org.junit.runners.Parameterized
@ -48,12 +49,12 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values() fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
} }
private lateinit var factory: CheckpointSerializationFactory @get:Rule
val serializationRule = CheckpointSerializationEnvironmentRule()
private lateinit var context: CheckpointSerializationContext private lateinit var context: CheckpointSerializationContext
@Before @Before
fun setup() { fun setup() {
factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = CheckpointSerializationContextImpl( context = CheckpointSerializationContextImpl(
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
@ -69,15 +70,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() { fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z") val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday) val mike = Person("mike", birthday)
val bits = mike.checkpointSerialize(factory, context) val bits = mike.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday)) assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("mike", birthday))
} }
@Test @Test
fun `null values`() { fun `null values`() {
val bob = Person("bob", null) val bob = Person("bob", null)
val bits = bob.checkpointSerialize(factory, context) val bits = bob.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null)) assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("bob", null))
} }
@Test @Test
@ -85,10 +86,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences() val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence() val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj } val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext) val deserialisedList = originalList.checkpointSerialize(noReferencesContext).checkpointDeserialize(noReferencesContext)
originalList += obj originalList += obj
deserialisedList += obj deserialisedList += obj
assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext)) assertThat(deserialisedList.checkpointSerialize(noReferencesContext)).isEqualTo(originalList.checkpointSerialize(noReferencesContext))
} }
@Test @Test
@ -105,14 +106,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant this += instant
this += instant this += instant
} }
assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext)) assertThat(listWithSameInstances.checkpointSerialize(noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(noReferencesContext))
} }
@Test @Test
fun `cyclic object graph`() { fun `cyclic object graph`() {
val cyclic = Cyclic(3) val cyclic = Cyclic(3)
val bits = cyclic.checkpointSerialize(factory, context) val bits = cyclic.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic) assertThat(bits.checkpointDeserialize(context)).isEqualTo(cyclic)
} }
@Test @Test
@ -124,7 +125,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign) signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) } assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val deserialisedKeyPair = keyPair.checkpointSerialize(context).checkpointDeserialize(context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign) deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) } assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -132,28 +133,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `write and read Kotlin object singleton`() { fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.checkpointSerialize(factory, context) val serialised = TestSingleton.checkpointSerialize(context)
val deserialised = serialised.checkpointDeserialize(factory, context) val deserialised = serialised.checkpointDeserialize(context)
assertThat(deserialised).isSameAs(TestSingleton) assertThat(deserialised).isSameAs(TestSingleton)
} }
@Test @Test
fun `check Kotlin EmptyList can be serialised`() { fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedList.size) assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass) assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
} }
@Test @Test
fun `check Kotlin EmptySet can be serialised`() { fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedSet.size) assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass) assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
} }
@Test @Test
fun `check Kotlin EmptyMap can be serialised`() { fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedMap.size) assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass) assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
} }
@ -161,7 +162,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `InputStream serialisation`() { fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() } val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(context).checkpointDeserialize(context)
for (i in 0..12344) { for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte()) assertEquals(rubbish[i], readRubbishStream.read().toByte())
} }
@ -171,7 +172,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `InputStream serialisation does not write trailing garbage`() { fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() } val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator() val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(context).checkpointDeserialize(context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) } byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext()) assertFalse(streams.hasNext())
} }
@ -182,8 +183,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray() val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)) val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.checkpointSerialize(factory, context).bytes val serializedMetaData = meta.checkpointSerialize(context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context) val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(context)
assertEquals(meta2, meta) assertEquals(meta2, meta)
} }
@ -191,7 +192,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize Logger`() { fun `serialize - deserialize Logger`() {
val storageContext: CheckpointSerializationContext = context val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName") val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext) val logger2 = logger.checkpointSerialize(storageContext).checkpointDeserialize(storageContext)
assertEquals(logger.name, logger2.name) assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2) assertTrue(logger === logger2)
} }
@ -203,7 +204,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish), SecureHash.sha256(rubbish),
rubbish.size, rubbish.size,
rubbish.inputStream() rubbish.inputStream()
).checkpointSerialize(factory, context).checkpointDeserialize(factory, context) ).checkpointSerialize(context).checkpointDeserialize(context)
for (i in 0..12344) { for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte()) assertEquals(rubbish[i], readRubbishStream.read().toByte())
} }
@ -230,8 +231,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32 31, 32
)) ))
val serializedBytes = expected.checkpointSerialize(factory, context) val serializedBytes = expected.checkpointSerialize(context)
val actual = serializedBytes.checkpointDeserialize(factory, context) val actual = serializedBytes.checkpointDeserialize(context)
assertEquals(expected, actual) assertEquals(expected, actual)
} }
@ -278,14 +279,13 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
} }
} }
Tmp() Tmp()
val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = CheckpointSerializationContextImpl( val context = CheckpointSerializationContextImpl(
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
null) null)
pt.checkpointSerialize(factory, context) pt.checkpointSerialize(context)
} }
@Test @Test
@ -293,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar") val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1") val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide) exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(exception.message, exception2.message) assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size) assertEquals(1, exception2.suppressed.size)
@ -308,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `serialize - deserialize Exception no suppressed`() { fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar") val exception = IllegalArgumentException("fooBar")
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(exception.message, exception2.message) assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size) assertEquals(0, exception2.suppressed.size)
@ -322,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() { fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256() val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash) val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(randomHash, exception2.requested) assertEquals(randomHash, exception2.requested)
} }
@ -330,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() { fun `compression has the desired effect`() {
compression ?: return compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.checkpointSerialize(factory, context) val compressed = data.checkpointSerialize(context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03) assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.checkpointDeserialize(factory, context)) assertArrayEquals(data, compressed.checkpointDeserialize(context))
} }
@Test @Test
fun `a particular encoding can be banned for deserialization`() { fun `a particular encoding can be banned for deserialization`() {
compression ?: return compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression) doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".checkpointSerialize(factory, context) val compressed = "whatever".checkpointSerialize(context)
catchThrowable { compressed.checkpointDeserialize(factory, context) }.run { catchThrowable { compressed.checkpointDeserialize(context) }.run {
assertSame<Any>(KryoException::class.java, javaClass) assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message) assertEquals(encodingNotPermittedFormat.format(compression), message)
} }
@ -351,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray) class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000)) val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size val uncompressedSize = obj.checkpointSerialize(context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size val compressedSize = obj.checkpointSerialize(context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade. // If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize) assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize) assertEquals(1111, compressedSize)

View File

@ -5,7 +5,7 @@ import net.corda.core.DeleteForDJVM
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.CheckpointSerializer
val serializationContextKey = SerializeAsTokenContext::class.java val serializationContextKey = SerializeAsTokenContext::class.java
@ -70,8 +70,8 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser
*/ */
@DeleteForDJVM @DeleteForDJVM
class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext { class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext {
constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, { constructor(toBeTokenized: Any, serializer: CheckpointSerializer, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializationFactory.serialize(toBeTokenized, context.withTokenContext(this)) serializer.serialize(toBeTokenized, context.withTokenContext(this))
}) })
private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>() private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>()

View File

@ -2,14 +2,14 @@ package net.corda.serialization.internal;
import net.corda.core.serialization.*; import net.corda.core.serialization.*;
import net.corda.core.serialization.internal.CheckpointSerializationContext; import net.corda.core.serialization.internal.CheckpointSerializationContext;
import net.corda.core.serialization.internal.CheckpointSerializationFactory; import net.corda.core.serialization.internal.CheckpointSerializer;
import net.corda.node.serialization.kryo.CordaClosureSerializer; import net.corda.node.serialization.kryo.CordaClosureSerializer;
import net.corda.testing.core.SerializationEnvironmentRule;
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule; import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import java.io.NotSerializableException;
import java.io.Serializable; import java.io.Serializable;
import java.util.Collections; import java.util.Collections;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
@ -23,12 +23,11 @@ public final class LambdaCheckpointSerializationTest {
public final CheckpointSerializationEnvironmentRule testCheckpointSerialization = public final CheckpointSerializationEnvironmentRule testCheckpointSerialization =
new CheckpointSerializationEnvironmentRule(); new CheckpointSerializationEnvironmentRule();
private CheckpointSerializationFactory factory;
private CheckpointSerializationContext context; private CheckpointSerializationContext context;
private CheckpointSerializer serializer;
@Before @Before
public void setup() { public void setup() {
factory = testCheckpointSerialization.getCheckpointSerializationFactory();
context = new CheckpointSerializationContextImpl( context = new CheckpointSerializationContextImpl(
getClass().getClassLoader(), getClass().getClassLoader(),
AllWhitelist.INSTANCE, AllWhitelist.INSTANCE,
@ -36,6 +35,8 @@ public final class LambdaCheckpointSerializationTest {
true, true,
null null
); );
serializer = testCheckpointSerialization.getCheckpointSerializer();
} }
@Test @Test
@ -63,11 +64,11 @@ public final class LambdaCheckpointSerializationTest {
assertThat(throwable).hasMessage(CordaClosureSerializer.ERROR_MESSAGE); assertThat(throwable).hasMessage(CordaClosureSerializer.ERROR_MESSAGE);
} }
private <T> SerializedBytes<T> serialize(final T target) { private <T> SerializedBytes<T> serialize(final T target) throws NotSerializableException {
return factory.serialize(target, context); return serializer.serialize(target, context);
} }
private <T> T deserialize(final SerializedBytes<? extends T> bytes, final Class<T> type) { private <T> T deserialize(final SerializedBytes<? extends T> bytes, final Class<T> type) throws NotSerializableException {
return factory.deserialize(bytes, type, context); return serializer.deserialize(bytes, type, context);
} }
} }

View File

@ -4,11 +4,9 @@ import net.corda.core.contracts.ContractAttachment
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
@ -27,24 +25,25 @@ class ContractAttachmentSerializerTest {
@JvmField @JvmField
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
private lateinit var contextWithToken: CheckpointSerializationContext private lateinit var contextWithToken: CheckpointSerializationContext
private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock()) private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock())
@Before @Before
fun setup() { fun setup() {
factory = testCheckpointSerialization.checkpointSerializationFactory contextWithToken = testCheckpointSerialization.checkpointSerializationContext.withTokenContext(
context = testCheckpointSerialization.checkpointSerializationContext CheckpointSerializeAsTokenContextImpl(
contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices)) Any(),
testCheckpointSerialization.checkpointSerializer,
testCheckpointSerialization.checkpointSerializationContext,
mockServices))
} }
@Test @Test
fun `write contract attachment and read it back`() { fun `write contract attachment and read it back`() {
val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID)
// no token context so will serialize the whole attachment // no token context so will serialize the whole attachment
val serialized = contractAttachment.checkpointSerialize(factory, context) val serialized = contractAttachment.checkpointSerialize()
val deserialized = serialized.checkpointDeserialize(factory, context) val deserialized = serialized.checkpointDeserialize()
assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract) assertEquals(contractAttachment.contract, deserialized.contract)
@ -59,8 +58,8 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null) mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken) val deserialized = serialized.checkpointDeserialize(contextWithToken)
assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract) assertEquals(contractAttachment.contract, deserialized.contract)
@ -76,7 +75,7 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null) mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(contextWithToken)
assertThat(serialized.size).isLessThan(largeAttachmentSize) assertThat(serialized.size).isLessThan(largeAttachmentSize)
} }
@ -88,8 +87,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService // don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken) val deserialized = serialized.checkpointDeserialize(contextWithToken)
assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java) assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java)
} }
@ -100,8 +99,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService // don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(contextWithToken)
serialized.checkpointDeserialize(factory, contextWithToken) serialized.checkpointDeserialize(contextWithToken)
// MissingAttachmentsException thrown if we try to open attachment // MissingAttachmentsException thrown if we try to open attachment
} }

View File

@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
@ -14,7 +13,6 @@ import net.corda.node.serialization.kryo.CordaKryo
import net.corda.node.serialization.kryo.DefaultKryoCustomizer import net.corda.node.serialization.kryo.DefaultKryoCustomizer
import net.corda.node.serialization.kryo.kryoMagic import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Before import org.junit.Before
@ -28,12 +26,10 @@ class SerializationTokenTest {
@JvmField @JvmField
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext private lateinit var context: CheckpointSerializationContext
@Before @Before
fun setup() { fun setup() {
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java) context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java)
} }
@ -49,16 +45,16 @@ class SerializationTokenTest {
override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size
} }
private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock()) private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, testCheckpointSerialization.checkpointSerializer, context, rigorousMock())
@Test @Test
fun `write token and read tokenizable`() { fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable() val tokenizableBefore = LargeTokenizable()
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) val serializedBytes = tokenizableBefore.checkpointSerialize(testContext)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext) val tokenizableAfter = serializedBytes.checkpointDeserialize(testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore) assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
} }
@ -69,8 +65,8 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) val serializedBytes = tokenizableBefore.checkpointSerialize(testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext) val tokenizableAfter = serializedBytes.checkpointDeserialize(testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore) assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
} }
@ -79,7 +75,7 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>()) val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
tokenizableBefore.checkpointSerialize(factory, testContext) tokenizableBefore.checkpointSerialize(testContext)
} }
@Test(expected = UnsupportedOperationException::class) @Test(expected = UnsupportedOperationException::class)
@ -87,14 +83,14 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>()) val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(factory, testContext) val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(testContext)
serializedBytes.checkpointDeserialize(factory, testContext) serializedBytes.checkpointDeserialize(testContext)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `no context set`() { fun `no context set`() {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
tokenizableBefore.checkpointSerialize(factory, context) tokenizableBefore.checkpointSerialize(context)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
@ -112,7 +108,7 @@ class SerializationTokenTest {
kryo.writeObject(it, emptyList<Any>()) kryo.writeObject(it, emptyList<Any>())
} }
val serializedBytes = SerializedBytes<Any>(stream.toByteArray()) val serializedBytes = SerializedBytes<Any>(stream.toByteArray())
serializedBytes.checkpointDeserialize(factory, testContext) serializedBytes.checkpointDeserialize(testContext)
} }
private class WrongTypeSerializeAsToken : SerializeAsToken { private class WrongTypeSerializeAsToken : SerializeAsToken {
@ -128,7 +124,7 @@ class SerializationTokenTest {
val tokenizableBefore = WrongTypeSerializeAsToken() val tokenizableBefore = WrongTypeSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) val serializedBytes = tokenizableBefore.checkpointSerialize(testContext)
serializedBytes.checkpointDeserialize(factory, testContext) serializedBytes.checkpointDeserialize(testContext)
} }
} }

View File

@ -3,9 +3,7 @@ package net.corda.testing.core
import com.nhaarman.mockito_kotlin.any import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.doAnswer import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.DoNotImplement
import net.corda.core.internal.staticField import net.corda.core.internal.staticField
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv import net.corda.testing.common.internal.asContextEnv
@ -40,7 +38,7 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */ /** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T { fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return SerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task) return SerializationEnvironmentRule().apply { init() }.runTask(task)
} }
} }
@ -48,14 +46,14 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
val serializationFactory get() = env.serializationFactory val serializationFactory get() = env.serializationFactory
override fun apply(base: Statement, description: Description): Statement { override fun apply(base: Statement, description: Description): Statement {
init(description.toString()) init()
return object : Statement() { return object : Statement() {
override fun evaluate() = runTask { base.evaluate() } override fun evaluate() = runTask { base.evaluate() }
} }
} }
private fun init(envLabel: String) { private fun init() {
env = createTestSerializationEnv(envLabel) env = createTestSerializationEnv()
} }
private fun <T> runTask(task: (SerializationEnvironment) -> T): T { private fun <T> runTask(task: (SerializationEnvironment) -> T): T {

View File

@ -39,7 +39,7 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean =
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */ /** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T { fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task) return CheckpointSerializationEnvironmentRule().apply { init() }.runTask(task)
} }
} }
@ -47,14 +47,14 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean =
private lateinit var env: SerializationEnvironment private lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement { override fun apply(base: Statement, description: Description): Statement {
init(description.toString()) init()
return object : Statement() { return object : Statement() {
override fun evaluate() = runTask { base.evaluate() } override fun evaluate() = runTask { base.evaluate() }
} }
} }
private fun init(envLabel: String) { private fun init() {
env = createTestSerializationEnv(envLabel) env = createTestSerializationEnv()
} }
private fun <T> runTask(task: (SerializationEnvironment) -> T): T { private fun <T> runTask(task: (SerializationEnvironment) -> T): T {
@ -65,7 +65,6 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean =
} }
} }
val checkpointSerializationFactory get() = env.checkpointSerializationFactory
val checkpointSerializationContext get() = env.checkpointContext val checkpointSerializationContext get() = env.checkpointContext
val checkpointSerializer get() = env.checkpointSerializer
} }

View File

@ -4,11 +4,10 @@ import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.* import net.corda.core.serialization.internal.*
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoSerializationScheme import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
@ -30,22 +29,20 @@ fun <T> withoutTestSerialization(callable: () -> T): T { // TODO: Delete this, s
} }
} }
internal fun createTestSerializationEnv(label: String): SerializationEnvironmentImpl { internal fun createTestSerializationEnv(): SerializationEnvironment {
val factory = SerializationFactoryImpl().apply { val factory = SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList())) registerScheme(AMQPClientSerializationScheme(emptyList()))
registerScheme(AMQPServerSerializationScheme(emptyList())) registerScheme(AMQPServerSerializationScheme(emptyList()))
} }
return object : SerializationEnvironmentImpl( return SerializationEnvironment.with(
factory, factory,
AMQP_P2P_CONTEXT, AMQP_P2P_CONTEXT,
AMQP_RPC_SERVER_CONTEXT, AMQP_RPC_SERVER_CONTEXT,
AMQP_RPC_CLIENT_CONTEXT, AMQP_RPC_CLIENT_CONTEXT,
AMQP_STORAGE_CONTEXT, AMQP_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT, KRYO_CHECKPOINT_CONTEXT,
CheckpointSerializationFactory(KryoSerializationScheme) KryoCheckpointSerializer
) { )
override fun toString() = "testSerializationEnv($label)"
}
} }
/** /**
@ -54,7 +51,7 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
*/ */
fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment { fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
return if (armed) { return if (armed) {
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") { object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv() {
override fun unset() { override fun unset() {
_globalSerializationEnv.set(null) _globalSerializationEnv.set(null)
inVMExecutors.remove(this) inVMExecutors.remove(this)

View File

@ -8,11 +8,10 @@ import net.corda.cliutils.CordaCliWrapper
import net.corda.cliutils.ExitCodes import net.corda.cliutils.ExitCodes
import net.corda.cliutils.start import net.corda.cliutils.start
import net.corda.core.internal.isRegularFile import net.corda.core.internal.isRegularFile
import net.corda.core.internal.rootMessage
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.base64ToByteArray import net.corda.core.utilities.base64ToByteArray
import net.corda.core.utilities.hexToByteArray import net.corda.core.utilities.hexToByteArray
@ -22,7 +21,6 @@ import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
import net.corda.serialization.internal.amqp.DeserializationInput import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.amqpMagic import net.corda.serialization.internal.amqp.amqpMagic
import org.slf4j.event.Level import org.slf4j.event.Level
import picocli.CommandLine
import picocli.CommandLine.* import picocli.CommandLine.*
import java.io.PrintStream import java.io.PrintStream
import java.net.MalformedURLException import java.net.MalformedURLException
@ -128,7 +126,7 @@ class BlobInspector : CordaCliWrapper("blob-inspector", "Convert AMQP serialised
private fun initialiseSerialization() { private fun initialiseSerialization() {
// Deserialise with the lenient carpenter as we only care for the AMQP field getters // Deserialise with the lenient carpenter as we only care for the AMQP field getters
_contextSerializationEnv.set(SerializationEnvironmentImpl( _contextSerializationEnv.set(SerializationEnvironment.with(
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPInspectorSerializationScheme) registerScheme(AMQPInspectorSerializationScheme)
}, },

View File

@ -2,7 +2,7 @@ package net.corda.demobench
import javafx.scene.image.Image import javafx.scene.image.Image
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.demobench.views.DemoBenchView import net.corda.demobench.views.DemoBenchView
import net.corda.serialization.internal.AMQP_P2P_CONTEXT import net.corda.serialization.internal.AMQP_P2P_CONTEXT
@ -56,7 +56,7 @@ class DemoBench : App(DemoBenchView::class) {
} }
private fun initialiseSerialization() { private fun initialiseSerialization() {
nodeSerializationEnv = SerializationEnvironmentImpl( nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList())) registerScheme(AMQPClientSerializationScheme(emptyList()))
}, },

View File

@ -1,9 +1,10 @@
package net.corda.bootstrapper.serialization package net.corda.bootstrapper.serialization
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.serialization.internal.AMQP_P2P_CONTEXT import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.SerializationFactoryImpl import net.corda.serialization.internal.SerializationFactoryImpl
@ -14,14 +15,16 @@ class SerializationEngine {
synchronized(this) { synchronized(this) {
if (nodeSerializationEnv == null) { if (nodeSerializationEnv == null) {
val classloader = this::class.java.classLoader val classloader = this::class.java.classLoader
nodeSerializationEnv = SerializationEnvironmentImpl( nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(emptyList())) registerScheme(AMQPServerSerializationScheme(emptyList()))
}, },
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader),
checkpointSerializer = KryoCheckpointSerializer
) )
} }
} }