CORDA-2006: Simplify checkpoint serialization (#4042)

* CORDA-2006: Simplify checkpoint serialization

* Supply rule to KryoTest
This commit is contained in:
Dominic Fox 2018-10-08 13:39:28 +01:00 committed by GitHub
parent c88d3d8c1b
commit d9ea19855f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 186 additions and 311 deletions

View File

@ -6,7 +6,6 @@ import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.*
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.serialization.internal.*
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
@ -35,7 +34,7 @@ class AMQPClientSerializationScheme(
}
fun createSerializationEnv(classLoader: ClassLoader? = null): SerializationEnvironment {
return SerializationEnvironmentImpl(
return SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList()))
},

View File

@ -1,74 +0,0 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/**
* A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization
* context.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private var _currentContext: CheckpointSerializationContext? = null
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext
if (context != null) _currentContext = context
try {
return block()
} finally {
if (context != null) _currentContext = priorContext
}
}
companion object {
/**
* A default factory for serialization/deserialization.
*/
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}

View File

@ -4,7 +4,7 @@ import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.P2P
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.serialization.internal.*
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
@ -58,13 +58,11 @@ class LocalSerializationRule(private val label: String) : TestRule {
_contextSerializationEnv.set(null)
}
private fun createTestSerializationEnv(): SerializationEnvironmentImpl {
private fun createTestSerializationEnv(): SerializationEnvironment {
val factory = SerializationFactoryImpl(mutableMapOf()).apply {
registerScheme(AMQPSerializationScheme(emptySet(), AccessOrderLinkedHashMap(128)))
}
return object : SerializationEnvironmentImpl(factory, AMQP_P2P_CONTEXT) {
override fun toString() = "testSerializationEnv($label)"
}
return SerializationEnvironment.with(factory, AMQP_P2P_CONTEXT)
}
private class AMQPSerializationScheme(

View File

@ -13,75 +13,12 @@ import java.io.NotSerializableException
object CheckpointSerializationDefaults {
@DeleteForDJVM
val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory
}
/**
* A class for serializing and deserializing objects at checkpoints, using Kryo serialization.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private val _currentContext = ThreadLocal<CheckpointSerializationContext?>()
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}
companion object {
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
val CHECKPOINT_SERIALIZER get() = effectiveSerializationEnv.checkpointSerializer
}
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationScheme {
interface CheckpointSerializer {
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T
@ -167,32 +104,36 @@ interface CheckpointSerializationContext {
/*
* Convenience extension method for deserializing a ByteSequence, utilising the default factory.
*/
inline fun <reified T : Any> ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
@JvmOverloads
inline fun <reified T : Any> ByteSequence.checkpointDeserialize(
context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T {
return effectiveSerializationEnv.checkpointSerializer.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory.
*/
inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
@JvmOverloads
inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(
context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T {
return effectiveSerializationEnv.checkpointSerializer.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing a ByteArray, utilising the default factory.
*/
inline fun <reified T : Any> ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
@JvmOverloads
inline fun <reified T : Any> ByteArray.checkpointDeserialize(
context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T {
require(isNotEmpty()) { "Empty bytes" }
return this.sequence().checkpointDeserialize(serializationFactory, context)
return this.sequence().checkpointDeserialize(context)
}
/**
* Convenience extension method for serializing an object of type T, utilising the default factory.
*/
fun <T : Any> T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): SerializedBytes<T> {
return serializationFactory.serialize(this, context)
@JvmOverloads
fun <T : Any> T.checkpointSerialize(
context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): SerializedBytes<T> {
return effectiveSerializationEnv.checkpointSerializer.serialize(this, context)
}

View File

@ -11,38 +11,63 @@ import net.corda.core.serialization.SerializationFactory
@KeepForDJVM
interface SerializationEnvironment {
companion object {
fun with(
serializationFactory: SerializationFactory,
p2pContext: SerializationContext,
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: CheckpointSerializationContext? = null,
checkpointSerializer: CheckpointSerializer? = null
): SerializationEnvironment =
SerializationEnvironmentImpl(
serializationFactory = serializationFactory,
p2pContext = p2pContext,
optionalRpcServerContext = rpcServerContext,
optionalRpcClientContext = rpcClientContext,
optionalStorageContext = storageContext,
optionalCheckpointContext = checkpointContext,
optionalCheckpointSerializer = checkpointSerializer
)
}
val serializationFactory: SerializationFactory
val checkpointSerializationFactory: CheckpointSerializationFactory
val p2pContext: SerializationContext
val rpcServerContext: SerializationContext
val rpcClientContext: SerializationContext
val storageContext: SerializationContext
val checkpointSerializer: CheckpointSerializer
val checkpointContext: CheckpointSerializationContext
}
@KeepForDJVM
open class SerializationEnvironmentImpl(
private class SerializationEnvironmentImpl(
override val serializationFactory: SerializationFactory,
override val p2pContext: SerializationContext,
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: CheckpointSerializationContext? = null,
checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: CheckpointSerializationContext
override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory
private val optionalRpcServerContext: SerializationContext? = null,
private val optionalRpcClientContext: SerializationContext? = null,
private val optionalStorageContext: SerializationContext? = null,
private val optionalCheckpointContext: CheckpointSerializationContext? = null,
private val optionalCheckpointSerializer: CheckpointSerializer? = null) : SerializationEnvironment {
init {
rpcServerContext?.let { this.rpcServerContext = it }
rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it }
checkpointContext?.let { this.checkpointContext = it }
checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it }
}
override val rpcServerContext: SerializationContext get() = optionalRpcServerContext ?:
throw UnsupportedOperationException("RPC server serialization not supported in this environment")
override val rpcClientContext: SerializationContext get() = optionalRpcClientContext ?:
throw UnsupportedOperationException("RPC client serialization not supported in this environment")
override val storageContext: SerializationContext get() = optionalStorageContext ?:
throw UnsupportedOperationException("Storage serialization not supported in this environment")
override val checkpointContext: CheckpointSerializationContext get() = optionalCheckpointContext ?:
throw UnsupportedOperationException("Checkpoint serialization not supported in this environment")
override val checkpointSerializer: CheckpointSerializer get() = optionalCheckpointSerializer ?:
throw UnsupportedOperationException("Checkpoint serialization not supported in this environment")
}
private val _nodeSerializationEnv = SimpleToggleField<SerializationEnvironment>("nodeSerializationEnv", true)

View File

@ -1,7 +1,5 @@
package net.corda.core.flows;
import net.corda.core.serialization.internal.CheckpointSerializationDefaults;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory;
import net.corda.testing.core.SerializationEnvironmentRule;
@ -32,12 +30,10 @@ public class SerializationApiInJavaTest {
SerializationDefaults defaults = SerializationDefaults.INSTANCE;
SerializationFactory factory = defaults.getSERIALIZATION_FACTORY();
CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE;
CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY();
serialize("hello", factory, defaults.getP2P_CONTEXT());
serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT());
serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT());
serialize("hello", factory, defaults.getSTORAGE_CONTEXT());
checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT());
checkpointSerialize("hello");
}
}

View File

@ -14,7 +14,7 @@ import net.corda.core.node.services.AttachmentId
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.days
import net.corda.core.utilities.getOrThrow
@ -393,7 +393,7 @@ internal constructor(private val initSerEnv: Boolean,
// We need to to set serialization env, because generation of parameters is run from Cordform.
private fun initialiseSerialization() {
_contextSerializationEnv.set(SerializationEnvironmentImpl(
_contextSerializationEnv.set(SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPParametersSerializationScheme)
},

View File

@ -4,7 +4,6 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage
@ -21,7 +20,7 @@ object CheckpointVerifier {
*/
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->

View File

@ -21,8 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
@ -38,7 +37,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
@ -470,17 +469,19 @@ open class Node(configuration: NodeConfiguration,
private fun initialiseSerialization() {
if (!initialiseSerialization) return
val classloader = cordappLoader.appClassLoader
nodeSerializationEnv = SerializationEnvironmentImpl(
nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null, //even Shell embeded in the node connects via RPC to the node
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null) //even Shell embeded in the node connects via RPC to the node
checkpointSerializer = KryoCheckpointSerializer,
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
)
}
/** Starts a blocking event loop for message dispatch. */

View File

@ -12,10 +12,9 @@ import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationScheme
import net.corda.core.serialization.internal.CheckpointSerializer
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.*
import java.security.PublicKey
import java.util.concurrent.ConcurrentHashMap
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
@ -31,7 +30,7 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
}
object KryoSerializationScheme : CheckpointSerializationScheme {
object KryoCheckpointSerializer : CheckpointSerializer {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
private fun getPool(context: CheckpointSerializationContext): KryoPool {

View File

@ -127,7 +127,7 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)

View File

@ -22,7 +22,7 @@ import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger
@ -201,7 +201,7 @@ class RaftTransactionCommitLog<E, EK>(
class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = CheckpointSerializationFactory.defaultFactory
private val checkpointSerializer = CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.checkpointSerialize(context = context)
@ -213,7 +213,7 @@ class RaftTransactionCommitLog<E, EK>(
val size = buffer.readInt()
val serialized = ByteArray(size)
buffer.read(serialized)
return factory.deserialize(ByteSequence.of(serialized), type, context)
return checkpointSerializer.deserialize(ByteSequence.of(serialized), type, context)
}
}
}

View File

@ -13,7 +13,6 @@ import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
@ -23,11 +22,13 @@ import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.serialization.internal.*
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock
import org.assertj.core.api.Assertions.*
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@ -48,12 +49,12 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
private lateinit var factory: CheckpointSerializationFactory
@get:Rule
val serializationRule = CheckpointSerializationEnvironmentRule()
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
@ -69,15 +70,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday)
val bits = mike.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday))
val bits = mike.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("mike", birthday))
}
@Test
fun `null values`() {
val bob = Person("bob", null)
val bits = bob.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null))
val bits = bob.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("bob", null))
}
@Test
@ -85,10 +86,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext)
val deserialisedList = originalList.checkpointSerialize(noReferencesContext).checkpointDeserialize(noReferencesContext)
originalList += obj
deserialisedList += obj
assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext))
assertThat(deserialisedList.checkpointSerialize(noReferencesContext)).isEqualTo(originalList.checkpointSerialize(noReferencesContext))
}
@Test
@ -105,14 +106,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant
this += instant
}
assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext))
assertThat(listWithSameInstances.checkpointSerialize(noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(noReferencesContext))
}
@Test
fun `cyclic object graph`() {
val cyclic = Cyclic(3)
val bits = cyclic.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic)
val bits = cyclic.checkpointSerialize(context)
assertThat(bits.checkpointDeserialize(context)).isEqualTo(cyclic)
}
@Test
@ -124,7 +125,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedKeyPair = keyPair.checkpointSerialize(context).checkpointDeserialize(context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -132,28 +133,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.checkpointSerialize(factory, context)
val deserialised = serialised.checkpointDeserialize(factory, context)
val serialised = TestSingleton.checkpointSerialize(context)
val deserialised = serialised.checkpointDeserialize(context)
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
}
@Test
fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
}
@Test
fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
}
@ -161,7 +162,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(context).checkpointDeserialize(context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -171,7 +172,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator()
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(context).checkpointDeserialize(context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext())
}
@ -182,8 +183,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.checkpointSerialize(factory, context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context)
val serializedMetaData = meta.checkpointSerialize(context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(context)
assertEquals(meta2, meta)
}
@ -191,7 +192,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize Logger`() {
val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext)
val logger2 = logger.checkpointSerialize(storageContext).checkpointDeserialize(storageContext)
assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2)
}
@ -203,7 +204,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish),
rubbish.size,
rubbish.inputStream()
).checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
).checkpointSerialize(context).checkpointDeserialize(context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -230,8 +231,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32
))
val serializedBytes = expected.checkpointSerialize(factory, context)
val actual = serializedBytes.checkpointDeserialize(factory, context)
val serializedBytes = expected.checkpointSerialize(context)
val actual = serializedBytes.checkpointDeserialize(context)
assertEquals(expected, actual)
}
@ -278,14 +279,13 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
}
}
Tmp()
val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
null)
pt.checkpointSerialize(factory, context)
pt.checkpointSerialize(context)
}
@Test
@ -293,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size)
@ -308,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar")
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size)
@ -322,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context)
assertEquals(randomHash, exception2.requested)
}
@ -330,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() {
compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.checkpointSerialize(factory, context)
val compressed = data.checkpointSerialize(context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.checkpointDeserialize(factory, context))
assertArrayEquals(data, compressed.checkpointDeserialize(context))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".checkpointSerialize(factory, context)
catchThrowable { compressed.checkpointDeserialize(factory, context) }.run {
val compressed = "whatever".checkpointSerialize(context)
catchThrowable { compressed.checkpointDeserialize(context) }.run {
assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -351,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
val uncompressedSize = obj.checkpointSerialize(context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize)

View File

@ -5,7 +5,7 @@ import net.corda.core.DeleteForDJVM
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.CheckpointSerializer
val serializationContextKey = SerializeAsTokenContext::class.java
@ -70,8 +70,8 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser
*/
@DeleteForDJVM
class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext {
constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializationFactory.serialize(toBeTokenized, context.withTokenContext(this))
constructor(toBeTokenized: Any, serializer: CheckpointSerializer, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializer.serialize(toBeTokenized, context.withTokenContext(this))
})
private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>()

View File

@ -2,14 +2,14 @@ package net.corda.serialization.internal;
import net.corda.core.serialization.*;
import net.corda.core.serialization.internal.CheckpointSerializationContext;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.core.serialization.internal.CheckpointSerializer;
import net.corda.node.serialization.kryo.CordaClosureSerializer;
import net.corda.testing.core.SerializationEnvironmentRule;
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import java.io.NotSerializableException;
import java.io.Serializable;
import java.util.Collections;
import java.util.concurrent.Callable;
@ -23,12 +23,11 @@ public final class LambdaCheckpointSerializationTest {
public final CheckpointSerializationEnvironmentRule testCheckpointSerialization =
new CheckpointSerializationEnvironmentRule();
private CheckpointSerializationFactory factory;
private CheckpointSerializationContext context;
private CheckpointSerializer serializer;
@Before
public void setup() {
factory = testCheckpointSerialization.getCheckpointSerializationFactory();
context = new CheckpointSerializationContextImpl(
getClass().getClassLoader(),
AllWhitelist.INSTANCE,
@ -36,6 +35,8 @@ public final class LambdaCheckpointSerializationTest {
true,
null
);
serializer = testCheckpointSerialization.getCheckpointSerializer();
}
@Test
@ -63,11 +64,11 @@ public final class LambdaCheckpointSerializationTest {
assertThat(throwable).hasMessage(CordaClosureSerializer.ERROR_MESSAGE);
}
private <T> SerializedBytes<T> serialize(final T target) {
return factory.serialize(target, context);
private <T> SerializedBytes<T> serialize(final T target) throws NotSerializableException {
return serializer.serialize(target, context);
}
private <T> T deserialize(final SerializedBytes<? extends T> bytes, final Class<T> type) {
return factory.deserialize(bytes, type, context);
private <T> T deserialize(final SerializedBytes<? extends T> bytes, final Class<T> type) throws NotSerializableException {
return serializer.deserialize(bytes, type, context);
}
}

View File

@ -4,11 +4,9 @@ import net.corda.core.contracts.ContractAttachment
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices
@ -27,24 +25,25 @@ class ContractAttachmentSerializerTest {
@JvmField
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
private lateinit var contextWithToken: CheckpointSerializationContext
private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock())
@Before
fun setup() {
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext
contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices))
contextWithToken = testCheckpointSerialization.checkpointSerializationContext.withTokenContext(
CheckpointSerializeAsTokenContextImpl(
Any(),
testCheckpointSerialization.checkpointSerializer,
testCheckpointSerialization.checkpointSerializationContext,
mockServices))
}
@Test
fun `write contract attachment and read it back`() {
val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID)
// no token context so will serialize the whole attachment
val serialized = contractAttachment.checkpointSerialize(factory, context)
val deserialized = serialized.checkpointDeserialize(factory, context)
val serialized = contractAttachment.checkpointSerialize()
val deserialized = serialized.checkpointDeserialize()
assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract)
@ -59,8 +58,8 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(contextWithToken)
val deserialized = serialized.checkpointDeserialize(contextWithToken)
assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract)
@ -76,7 +75,7 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(contextWithToken)
assertThat(serialized.size).isLessThan(largeAttachmentSize)
}
@ -88,8 +87,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(contextWithToken)
val deserialized = serialized.checkpointDeserialize(contextWithToken)
assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java)
}
@ -100,8 +99,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
serialized.checkpointDeserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(contextWithToken)
serialized.checkpointDeserialize(contextWithToken)
// MissingAttachmentsException thrown if we try to open attachment
}

View File

@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.OpaqueBytes
@ -14,7 +13,6 @@ import net.corda.node.serialization.kryo.CordaKryo
import net.corda.node.serialization.kryo.DefaultKryoCustomizer
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
@ -28,12 +26,10 @@ class SerializationTokenTest {
@JvmField
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java)
}
@ -49,16 +45,16 @@ class SerializationTokenTest {
override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size
}
private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock())
private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, testCheckpointSerialization.checkpointSerializer, context, rigorousMock())
@Test
fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(testContext)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@ -69,8 +65,8 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@ -79,7 +75,7 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
tokenizableBefore.checkpointSerialize(factory, testContext)
tokenizableBefore.checkpointSerialize(testContext)
}
@Test(expected = UnsupportedOperationException::class)
@ -87,14 +83,14 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(testContext)
serializedBytes.checkpointDeserialize(testContext)
}
@Test(expected = KryoException::class)
fun `no context set`() {
val tokenizableBefore = UnitSerializeAsToken()
tokenizableBefore.checkpointSerialize(factory, context)
tokenizableBefore.checkpointSerialize(context)
}
@Test(expected = KryoException::class)
@ -112,7 +108,7 @@ class SerializationTokenTest {
kryo.writeObject(it, emptyList<Any>())
}
val serializedBytes = SerializedBytes<Any>(stream.toByteArray())
serializedBytes.checkpointDeserialize(factory, testContext)
serializedBytes.checkpointDeserialize(testContext)
}
private class WrongTypeSerializeAsToken : SerializeAsToken {
@ -128,7 +124,7 @@ class SerializationTokenTest {
val tokenizableBefore = WrongTypeSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(testContext)
serializedBytes.checkpointDeserialize(testContext)
}
}

View File

@ -3,9 +3,7 @@ package net.corda.testing.core
import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.DoNotImplement
import net.corda.core.internal.staticField
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv
@ -40,7 +38,7 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return SerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task)
return SerializationEnvironmentRule().apply { init() }.runTask(task)
}
}
@ -48,14 +46,14 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
val serializationFactory get() = env.serializationFactory
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())
init()
return object : Statement() {
override fun evaluate() = runTask { base.evaluate() }
}
}
private fun init(envLabel: String) {
env = createTestSerializationEnv(envLabel)
private fun init() {
env = createTestSerializationEnv()
}
private fun <T> runTask(task: (SerializationEnvironment) -> T): T {

View File

@ -39,7 +39,7 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean =
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task)
return CheckpointSerializationEnvironmentRule().apply { init() }.runTask(task)
}
}
@ -47,14 +47,14 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean =
private lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())
init()
return object : Statement() {
override fun evaluate() = runTask { base.evaluate() }
}
}
private fun init(envLabel: String) {
env = createTestSerializationEnv(envLabel)
private fun init() {
env = createTestSerializationEnv()
}
private fun <T> runTask(task: (SerializationEnvironment) -> T): T {
@ -65,7 +65,6 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean =
}
}
val checkpointSerializationFactory get() = env.checkpointSerializationFactory
val checkpointSerializationContext get() = env.checkpointContext
val checkpointSerializer get() = env.checkpointSerializer
}

View File

@ -4,11 +4,10 @@ import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.DoNotImplement
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.*
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.serialization.internal.*
import net.corda.testing.core.SerializationEnvironmentRule
import java.util.concurrent.ConcurrentHashMap
@ -30,22 +29,20 @@ fun <T> withoutTestSerialization(callable: () -> T): T { // TODO: Delete this, s
}
}
internal fun createTestSerializationEnv(label: String): SerializationEnvironmentImpl {
internal fun createTestSerializationEnv(): SerializationEnvironment {
val factory = SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList()))
registerScheme(AMQPServerSerializationScheme(emptyList()))
}
return object : SerializationEnvironmentImpl(
return SerializationEnvironment.with(
factory,
AMQP_P2P_CONTEXT,
AMQP_RPC_SERVER_CONTEXT,
AMQP_RPC_CLIENT_CONTEXT,
AMQP_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT,
CheckpointSerializationFactory(KryoSerializationScheme)
) {
override fun toString() = "testSerializationEnv($label)"
}
KryoCheckpointSerializer
)
}
/**
@ -54,7 +51,7 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
*/
fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment {
return if (armed) {
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("<global>") {
object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv() {
override fun unset() {
_globalSerializationEnv.set(null)
inVMExecutors.remove(this)

View File

@ -8,11 +8,10 @@ import net.corda.cliutils.CordaCliWrapper
import net.corda.cliutils.ExitCodes
import net.corda.cliutils.start
import net.corda.core.internal.isRegularFile
import net.corda.core.internal.rootMessage
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.base64ToByteArray
import net.corda.core.utilities.hexToByteArray
@ -22,7 +21,6 @@ import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.amqpMagic
import org.slf4j.event.Level
import picocli.CommandLine
import picocli.CommandLine.*
import java.io.PrintStream
import java.net.MalformedURLException
@ -128,7 +126,7 @@ class BlobInspector : CordaCliWrapper("blob-inspector", "Convert AMQP serialised
private fun initialiseSerialization() {
// Deserialise with the lenient carpenter as we only care for the AMQP field getters
_contextSerializationEnv.set(SerializationEnvironmentImpl(
_contextSerializationEnv.set(SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPInspectorSerializationScheme)
},

View File

@ -2,7 +2,7 @@ package net.corda.demobench
import javafx.scene.image.Image
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.demobench.views.DemoBenchView
import net.corda.serialization.internal.AMQP_P2P_CONTEXT
@ -56,7 +56,7 @@ class DemoBench : App(DemoBenchView::class) {
}
private fun initialiseSerialization() {
nodeSerializationEnv = SerializationEnvironmentImpl(
nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList()))
},

View File

@ -1,9 +1,10 @@
package net.corda.bootstrapper.serialization
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoCheckpointSerializer
import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.SerializationFactoryImpl
@ -14,14 +15,16 @@ class SerializationEngine {
synchronized(this) {
if (nodeSerializationEnv == null) {
val classloader = this::class.java.classLoader
nodeSerializationEnv = SerializationEnvironmentImpl(
nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(emptyList()))
},
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader),
checkpointSerializer = KryoCheckpointSerializer
)
}
}