diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt index df12645479..a0e2bfc307 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/serialization/amqp/AMQPClientSerializationScheme.kt @@ -6,7 +6,6 @@ import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext.* import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.internal.SerializationEnvironment -import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.serialization.internal.* import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme @@ -35,7 +34,7 @@ class AMQPClientSerializationScheme( } fun createSerializationEnv(classLoader: ClassLoader? = null): SerializationEnvironment { - return SerializationEnvironmentImpl( + return SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPClientSerializationScheme(emptyList())) }, diff --git a/core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt b/core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt deleted file mode 100644 index dbb6fb54c0..0000000000 --- a/core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt +++ /dev/null @@ -1,74 +0,0 @@ -package net.corda.core.serialization.internal - -import net.corda.core.KeepForDJVM -import net.corda.core.serialization.SerializedBytes -import net.corda.core.utilities.ByteSequence -import java.io.NotSerializableException - -/** - * A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization - * context. - */ -@KeepForDJVM -class CheckpointSerializationFactory( - private val scheme: CheckpointSerializationScheme -) { - - val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext - - private val creator: List = Exception().stackTrace.asList() - - /** - * Deserialize the bytes in to an object, using the prefixed bytes to determine the format. - * - * @param byteSequence The bytes to deserialize, including a format header prefix. - * @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown. - * @param context A context that configures various parameters to deserialization. - */ - @Throws(NotSerializableException::class) - fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T { - return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) } - } - - /** - * Serialize an object to bytes using the preferred serialization format version from the context. - * - * @param obj The object to be serialized. - * @param context A context that configures various parameters to serialization, including the serialization format version. - */ - fun serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes { - return withCurrentContext(context) { scheme.serialize(obj, context) } - } - - override fun toString(): String { - return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}" - } - - override fun equals(other: Any?): Boolean { - return other is CheckpointSerializationFactory && other.scheme == this.scheme - } - - override fun hashCode(): Int = scheme.hashCode() - - private var _currentContext: CheckpointSerializationContext? = null - - /** - * Change the current context inside the block to that supplied. - */ - fun withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T { - val priorContext = _currentContext - if (context != null) _currentContext = context - try { - return block() - } finally { - if (context != null) _currentContext = priorContext - } - } - - companion object { - /** - * A default factory for serialization/deserialization. - */ - val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory - } -} \ No newline at end of file diff --git a/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/LocalSerializationRule.kt b/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/LocalSerializationRule.kt index 15848a4be4..e388c7c6d9 100644 --- a/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/LocalSerializationRule.kt +++ b/core-deterministic/testing/verifier/src/main/kotlin/net/corda/deterministic/verifier/LocalSerializationRule.kt @@ -4,7 +4,7 @@ import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext.UseCase.P2P import net.corda.core.serialization.SerializationCustomSerializer -import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.serialization.internal.* import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme @@ -58,13 +58,11 @@ class LocalSerializationRule(private val label: String) : TestRule { _contextSerializationEnv.set(null) } - private fun createTestSerializationEnv(): SerializationEnvironmentImpl { + private fun createTestSerializationEnv(): SerializationEnvironment { val factory = SerializationFactoryImpl(mutableMapOf()).apply { registerScheme(AMQPSerializationScheme(emptySet(), AccessOrderLinkedHashMap(128))) } - return object : SerializationEnvironmentImpl(factory, AMQP_P2P_CONTEXT) { - override fun toString() = "testSerializationEnv($label)" - } + return SerializationEnvironment.with(factory, AMQP_P2P_CONTEXT) } private class AMQPSerializationScheme( diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt index 448d1ab25f..6769b73b03 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt @@ -13,75 +13,12 @@ import java.io.NotSerializableException object CheckpointSerializationDefaults { @DeleteForDJVM val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext - val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory -} - -/** - * A class for serializing and deserializing objects at checkpoints, using Kryo serialization. - */ -@KeepForDJVM -class CheckpointSerializationFactory( - private val scheme: CheckpointSerializationScheme -) { - - val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext - - private val creator: List = Exception().stackTrace.asList() - - /** - * Deserialize the bytes in to an object, using the prefixed bytes to determine the format. - * - * @param byteSequence The bytes to deserialize, including a format header prefix. - * @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown. - * @param context A context that configures various parameters to deserialization. - */ - fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T { - return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) } - } - - /** - * Serialize an object to bytes using the preferred serialization format version from the context. - * - * @param obj The object to be serialized. - * @param context A context that configures various parameters to serialization, including the serialization format version. - */ - fun serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes { - return withCurrentContext(context) { scheme.serialize(obj, context) } - } - - override fun toString(): String { - return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}" - } - - override fun equals(other: Any?): Boolean { - return other is CheckpointSerializationFactory && other.scheme == this.scheme - } - - override fun hashCode(): Int = scheme.hashCode() - - private val _currentContext = ThreadLocal() - - /** - * Change the current context inside the block to that supplied. - */ - fun withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T { - val priorContext = _currentContext.get() - if (context != null) _currentContext.set(context) - try { - return block() - } finally { - if (context != null) _currentContext.set(priorContext) - } - } - - companion object { - val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory - } + val CHECKPOINT_SERIALIZER get() = effectiveSerializationEnv.checkpointSerializer } @KeepForDJVM @DoNotImplement -interface CheckpointSerializationScheme { +interface CheckpointSerializer { @Throws(NotSerializableException::class) fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T @@ -167,32 +104,36 @@ interface CheckpointSerializationContext { /* * Convenience extension method for deserializing a ByteSequence, utilising the default factory. */ -inline fun ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, - context: CheckpointSerializationContext): T { - return serializationFactory.deserialize(this, T::class.java, context) +@JvmOverloads +inline fun ByteSequence.checkpointDeserialize( + context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T { + return effectiveSerializationEnv.checkpointSerializer.deserialize(this, T::class.java, context) } /** * Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory. */ -inline fun SerializedBytes.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, - context: CheckpointSerializationContext): T { - return serializationFactory.deserialize(this, T::class.java, context) +@JvmOverloads +inline fun SerializedBytes.checkpointDeserialize( + context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T { + return effectiveSerializationEnv.checkpointSerializer.deserialize(this, T::class.java, context) } /** * Convenience extension method for deserializing a ByteArray, utilising the default factory. */ -inline fun ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, - context: CheckpointSerializationContext): T { +@JvmOverloads +inline fun ByteArray.checkpointDeserialize( + context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): T { require(isNotEmpty()) { "Empty bytes" } - return this.sequence().checkpointDeserialize(serializationFactory, context) + return this.sequence().checkpointDeserialize(context) } /** * Convenience extension method for serializing an object of type T, utilising the default factory. */ -fun T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, - context: CheckpointSerializationContext): SerializedBytes { - return serializationFactory.serialize(this, context) +@JvmOverloads +fun T.checkpointSerialize( + context: CheckpointSerializationContext = effectiveSerializationEnv.checkpointContext): SerializedBytes { + return effectiveSerializationEnv.checkpointSerializer.serialize(this, context) } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt index 441cd52be4..b213b6322d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt @@ -11,38 +11,63 @@ import net.corda.core.serialization.SerializationFactory @KeepForDJVM interface SerializationEnvironment { + + companion object { + fun with( + serializationFactory: SerializationFactory, + p2pContext: SerializationContext, + rpcServerContext: SerializationContext? = null, + rpcClientContext: SerializationContext? = null, + storageContext: SerializationContext? = null, + + checkpointContext: CheckpointSerializationContext? = null, + checkpointSerializer: CheckpointSerializer? = null + ): SerializationEnvironment = + SerializationEnvironmentImpl( + serializationFactory = serializationFactory, + p2pContext = p2pContext, + optionalRpcServerContext = rpcServerContext, + optionalRpcClientContext = rpcClientContext, + optionalStorageContext = storageContext, + optionalCheckpointContext = checkpointContext, + optionalCheckpointSerializer = checkpointSerializer + ) + } + val serializationFactory: SerializationFactory - val checkpointSerializationFactory: CheckpointSerializationFactory val p2pContext: SerializationContext val rpcServerContext: SerializationContext val rpcClientContext: SerializationContext val storageContext: SerializationContext + + val checkpointSerializer: CheckpointSerializer val checkpointContext: CheckpointSerializationContext } @KeepForDJVM -open class SerializationEnvironmentImpl( +private class SerializationEnvironmentImpl( override val serializationFactory: SerializationFactory, override val p2pContext: SerializationContext, - rpcServerContext: SerializationContext? = null, - rpcClientContext: SerializationContext? = null, - storageContext: SerializationContext? = null, - checkpointContext: CheckpointSerializationContext? = null, - checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment { - // Those that are passed in as null are never inited: - override lateinit var rpcServerContext: SerializationContext - override lateinit var rpcClientContext: SerializationContext - override lateinit var storageContext: SerializationContext - override lateinit var checkpointContext: CheckpointSerializationContext - override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory + private val optionalRpcServerContext: SerializationContext? = null, + private val optionalRpcClientContext: SerializationContext? = null, + private val optionalStorageContext: SerializationContext? = null, + private val optionalCheckpointContext: CheckpointSerializationContext? = null, + private val optionalCheckpointSerializer: CheckpointSerializer? = null) : SerializationEnvironment { - init { - rpcServerContext?.let { this.rpcServerContext = it } - rpcClientContext?.let { this.rpcClientContext = it } - storageContext?.let { this.storageContext = it } - checkpointContext?.let { this.checkpointContext = it } - checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it } - } + override val rpcServerContext: SerializationContext get() = optionalRpcServerContext ?: + throw UnsupportedOperationException("RPC server serialization not supported in this environment") + + override val rpcClientContext: SerializationContext get() = optionalRpcClientContext ?: + throw UnsupportedOperationException("RPC client serialization not supported in this environment") + + override val storageContext: SerializationContext get() = optionalStorageContext ?: + throw UnsupportedOperationException("Storage serialization not supported in this environment") + + override val checkpointContext: CheckpointSerializationContext get() = optionalCheckpointContext ?: + throw UnsupportedOperationException("Checkpoint serialization not supported in this environment") + + override val checkpointSerializer: CheckpointSerializer get() = optionalCheckpointSerializer ?: + throw UnsupportedOperationException("Checkpoint serialization not supported in this environment") } private val _nodeSerializationEnv = SimpleToggleField("nodeSerializationEnv", true) diff --git a/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java b/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java index 55e66c1766..9f48a7ba0a 100644 --- a/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java +++ b/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java @@ -1,7 +1,5 @@ package net.corda.core.flows; -import net.corda.core.serialization.internal.CheckpointSerializationDefaults; -import net.corda.core.serialization.internal.CheckpointSerializationFactory; import net.corda.core.serialization.SerializationDefaults; import net.corda.core.serialization.SerializationFactory; import net.corda.testing.core.SerializationEnvironmentRule; @@ -32,12 +30,10 @@ public class SerializationApiInJavaTest { SerializationDefaults defaults = SerializationDefaults.INSTANCE; SerializationFactory factory = defaults.getSERIALIZATION_FACTORY(); - CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE; - CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY(); serialize("hello", factory, defaults.getP2P_CONTEXT()); serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT()); serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT()); serialize("hello", factory, defaults.getSTORAGE_CONTEXT()); - checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT()); + checkpointSerialize("hello"); } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkBootstrapper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkBootstrapper.kt index 49d0704c84..1747cc2333 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkBootstrapper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkBootstrapper.kt @@ -14,7 +14,7 @@ import net.corda.core.node.services.AttachmentId import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize -import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.utilities.days import net.corda.core.utilities.getOrThrow @@ -393,7 +393,7 @@ internal constructor(private val initSerEnv: Boolean, // We need to to set serialization env, because generation of parameters is run from Cordform. private fun initialiseSerialization() { - _contextSerializationEnv.set(SerializationEnvironmentImpl( + _contextSerializationEnv.set(SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPParametersSerializationScheme) }, diff --git a/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt b/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt index 53f22b3147..48b8c71971 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt @@ -4,7 +4,6 @@ import net.corda.core.cordapp.Cordapp import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.node.ServiceHub -import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.node.services.api.CheckpointStorage @@ -21,7 +20,7 @@ object CheckpointVerifier { */ fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List) { val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( - CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) ) checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) -> diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index f1c9140938..3db80bdf25 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -21,8 +21,7 @@ import net.corda.core.messaging.RPCOps import net.corda.core.node.NetworkParameters import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub -import net.corda.core.serialization.internal.CheckpointSerializationFactory -import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger @@ -38,7 +37,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT -import net.corda.node.serialization.kryo.KryoSerializationScheme +import net.corda.node.serialization.kryo.KryoCheckpointSerializer import net.corda.node.services.Permissions import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.ServiceHubInternal @@ -470,17 +469,19 @@ open class Node(configuration: NodeConfiguration, private fun initialiseSerialization() { if (!initialiseSerialization) return val classloader = cordappLoader.appClassLoader - nodeSerializationEnv = SerializationEnvironmentImpl( + nodeSerializationEnv = SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps)) }, - checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), + rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null, //even Shell embeded in the node connects via RPC to the node storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), - checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader), - rpcClientContext = if (configuration.shouldInitCrashShell()) AMQP_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null) //even Shell embeded in the node connects via RPC to the node + + checkpointSerializer = KryoCheckpointSerializer, + checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) + ) } /** Starts a blocking event loop for message dispatch. */ diff --git a/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt b/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoCheckpointSerializer.kt similarity index 97% rename from node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt rename to node/src/main/kotlin/net/corda/node/serialization/kryo/KryoCheckpointSerializer.kt index 22edf8258e..cd74dafb4c 100644 --- a/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt +++ b/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoCheckpointSerializer.kt @@ -12,10 +12,9 @@ import com.esotericsoftware.kryo.serializers.ClosureSerializer import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.* import net.corda.core.serialization.internal.CheckpointSerializationContext -import net.corda.core.serialization.internal.CheckpointSerializationScheme +import net.corda.core.serialization.internal.CheckpointSerializer import net.corda.core.utilities.ByteSequence import net.corda.serialization.internal.* -import java.security.PublicKey import java.util.concurrent.ConcurrentHashMap val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0)) @@ -31,7 +30,7 @@ private object AutoCloseableSerialisationDetector : Serializer() override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") } -object KryoSerializationScheme : CheckpointSerializationScheme { +object KryoCheckpointSerializer : CheckpointSerializer { private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() private fun getPool(context: CheckpointSerializationContext): KryoPool { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index a47799e0aa..60fb528bcb 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -127,7 +127,7 @@ class SingleThreadedStateMachineManager( override fun start(tokenizableServices: List) { checkQuasarJavaAgentPresence() val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( - CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) ) this.checkpointSerializationContext = checkpointSerializationContext this.actionExecutor = makeActionExecutor(checkpointSerializationContext) diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt b/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt index 72fa52bdc1..c35ae146ab 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt @@ -22,7 +22,7 @@ import net.corda.core.internal.notary.isConsumedByTheSameTx import net.corda.core.internal.notary.validateTimeWindow import net.corda.core.serialization.* import net.corda.core.serialization.internal.CheckpointSerializationDefaults -import net.corda.core.serialization.internal.CheckpointSerializationFactory + import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.contextLogger @@ -201,7 +201,7 @@ class RaftTransactionCommitLog( class CordaKryoSerializer : TypeSerializer { private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY) - private val factory = CheckpointSerializationFactory.defaultFactory + private val checkpointSerializer = CheckpointSerializationDefaults.CHECKPOINT_SERIALIZER override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) { val serialized = obj.checkpointSerialize(context = context) @@ -213,7 +213,7 @@ class RaftTransactionCommitLog( val size = buffer.readInt() val serialized = ByteArray(size) buffer.read(serialized) - return factory.deserialize(ByteSequence.of(serialized), type, context) + return checkpointSerializer.deserialize(ByteSequence.of(serialized), type, context) } } } diff --git a/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt b/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt index 5598f38f67..abfec25bc2 100644 --- a/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt +++ b/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt @@ -13,7 +13,6 @@ import net.corda.core.crypto.* import net.corda.core.internal.FetchDataFlow import net.corda.core.serialization.* import net.corda.core.serialization.internal.CheckpointSerializationContext -import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.ByteSequence @@ -23,11 +22,13 @@ import net.corda.node.services.persistence.NodeAttachmentService import net.corda.serialization.internal.* import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.TestIdentity +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import net.corda.testing.internal.rigorousMock import org.assertj.core.api.Assertions.* import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Before +import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized @@ -48,12 +49,12 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun compression() = arrayOf(null) + CordaSerializationEncoding.values() } - private lateinit var factory: CheckpointSerializationFactory + @get:Rule + val serializationRule = CheckpointSerializationEnvironmentRule() private lateinit var context: CheckpointSerializationContext @Before fun setup() { - factory = CheckpointSerializationFactory(KryoSerializationScheme) context = CheckpointSerializationContextImpl( javaClass.classLoader, AllWhitelist, @@ -69,15 +70,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `simple data class`() { val birthday = Instant.parse("1984-04-17T00:30:00.00Z") val mike = Person("mike", birthday) - val bits = mike.checkpointSerialize(factory, context) - assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday)) + val bits = mike.checkpointSerialize(context) + assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("mike", birthday)) } @Test fun `null values`() { val bob = Person("bob", null) - val bits = bob.checkpointSerialize(factory, context) - assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null)) + val bits = bob.checkpointSerialize(context) + assertThat(bits.checkpointDeserialize(context)).isEqualTo(Person("bob", null)) } @Test @@ -85,10 +86,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { val noReferencesContext = context.withoutReferences() val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence() val originalList : ArrayList = ArrayList().apply { this += obj } - val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext) + val deserialisedList = originalList.checkpointSerialize(noReferencesContext).checkpointDeserialize(noReferencesContext) originalList += obj deserialisedList += obj - assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext)) + assertThat(deserialisedList.checkpointSerialize(noReferencesContext)).isEqualTo(originalList.checkpointSerialize(noReferencesContext)) } @Test @@ -105,14 +106,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { this += instant this += instant } - assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext)) + assertThat(listWithSameInstances.checkpointSerialize(noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(noReferencesContext)) } @Test fun `cyclic object graph`() { val cyclic = Cyclic(3) - val bits = cyclic.checkpointSerialize(factory, context) - assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic) + val bits = cyclic.checkpointSerialize(context) + assertThat(bits.checkpointDeserialize(context)).isEqualTo(cyclic) } @Test @@ -124,7 +125,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { signature.verify(bitsToSign) assertThatThrownBy { signature.verify(wrongBits) } - val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val deserialisedKeyPair = keyPair.checkpointSerialize(context).checkpointDeserialize(context) val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) deserialisedSignature.verify(bitsToSign) assertThatThrownBy { deserialisedSignature.verify(wrongBits) } @@ -132,28 +133,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `write and read Kotlin object singleton`() { - val serialised = TestSingleton.checkpointSerialize(factory, context) - val deserialised = serialised.checkpointDeserialize(factory, context) + val serialised = TestSingleton.checkpointSerialize(context) + val deserialised = serialised.checkpointDeserialize(context) assertThat(deserialised).isSameAs(TestSingleton) } @Test fun `check Kotlin EmptyList can be serialised`() { - val deserialisedList: List = emptyList().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val deserialisedList: List = emptyList().checkpointSerialize(context).checkpointDeserialize(context) assertEquals(0, deserialisedList.size) assertEquals(Collections.emptyList().javaClass, deserialisedList.javaClass) } @Test fun `check Kotlin EmptySet can be serialised`() { - val deserialisedSet: Set = emptySet().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val deserialisedSet: Set = emptySet().checkpointSerialize(context).checkpointDeserialize(context) assertEquals(0, deserialisedSet.size) assertEquals(Collections.emptySet().javaClass, deserialisedSet.javaClass) } @Test fun `check Kotlin EmptyMap can be serialised`() { - val deserialisedMap: Map = emptyMap().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val deserialisedMap: Map = emptyMap().checkpointSerialize(context).checkpointDeserialize(context) assertEquals(0, deserialisedMap.size) assertEquals(Collections.emptyMap().javaClass, deserialisedMap.javaClass) } @@ -161,7 +162,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `InputStream serialisation`() { val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() } - val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(context).checkpointDeserialize(context) for (i in 0..12344) { assertEquals(rubbish[i], readRubbishStream.read().toByte()) } @@ -171,7 +172,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `InputStream serialisation does not write trailing garbage`() { val byteArrays = listOf("123", "456").map { it.toByteArray() } - val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator() + val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(context).checkpointDeserialize(context).iterator() byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) } assertFalse(streams.hasNext()) } @@ -182,8 +183,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { val testBytes = testString.toByteArray() val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)) - val serializedMetaData = meta.checkpointSerialize(factory, context).bytes - val meta2 = serializedMetaData.checkpointDeserialize(factory, context) + val serializedMetaData = meta.checkpointSerialize(context).bytes + val meta2 = serializedMetaData.checkpointDeserialize(context) assertEquals(meta2, meta) } @@ -191,7 +192,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `serialize - deserialize Logger`() { val storageContext: CheckpointSerializationContext = context val logger = LoggerFactory.getLogger("aName") - val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext) + val logger2 = logger.checkpointSerialize(storageContext).checkpointDeserialize(storageContext) assertEquals(logger.name, logger2.name) assertTrue(logger === logger2) } @@ -203,7 +204,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { SecureHash.sha256(rubbish), rubbish.size, rubbish.inputStream() - ).checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + ).checkpointSerialize(context).checkpointDeserialize(context) for (i in 0..12344) { assertEquals(rubbish[i], readRubbishStream.read().toByte()) } @@ -230,8 +231,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 )) - val serializedBytes = expected.checkpointSerialize(factory, context) - val actual = serializedBytes.checkpointDeserialize(factory, context) + val serializedBytes = expected.checkpointSerialize(context) + val actual = serializedBytes.checkpointDeserialize(context) assertEquals(expected, actual) } @@ -278,14 +279,13 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { } } Tmp() - val factory = CheckpointSerializationFactory(KryoSerializationScheme) val context = CheckpointSerializationContextImpl( javaClass.classLoader, AllWhitelist, emptyMap(), true, null) - pt.checkpointSerialize(factory, context) + pt.checkpointSerialize(context) } @Test @@ -293,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { val exception = IllegalArgumentException("fooBar") val toBeSuppressedOnSenderSide = IllegalStateException("bazz1") exception.addSuppressed(toBeSuppressedOnSenderSide) - val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context) assertEquals(exception.message, exception2.message) assertEquals(1, exception2.suppressed.size) @@ -308,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `serialize - deserialize Exception no suppressed`() { val exception = IllegalArgumentException("fooBar") - val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context) assertEquals(exception.message, exception2.message) assertEquals(0, exception2.suppressed.size) @@ -322,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `serialize - deserialize HashNotFound`() { val randomHash = SecureHash.randomSHA256() val exception = FetchDataFlow.HashNotFound(randomHash) - val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) + val exception2 = exception.checkpointSerialize(context).checkpointDeserialize(context) assertEquals(randomHash, exception2.requested) } @@ -330,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `compression has the desired effect`() { compression ?: return val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } - val compressed = data.checkpointSerialize(factory, context) + val compressed = data.checkpointSerialize(context) assertEquals(.5, compressed.size.toDouble() / data.size, .03) - assertArrayEquals(data, compressed.checkpointDeserialize(factory, context)) + assertArrayEquals(data, compressed.checkpointDeserialize(context)) } @Test fun `a particular encoding can be banned for deserialization`() { compression ?: return doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression) - val compressed = "whatever".checkpointSerialize(factory, context) - catchThrowable { compressed.checkpointDeserialize(factory, context) }.run { + val compressed = "whatever".checkpointSerialize(context) + catchThrowable { compressed.checkpointDeserialize(context) }.run { assertSame(KryoException::class.java, javaClass) assertEquals(encodingNotPermittedFormat.format(compression), message) } @@ -351,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { class Holder(val holder: ByteArray) val obj = Holder(ByteArray(20000)) - val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size - val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size + val uncompressedSize = obj.checkpointSerialize(context.withEncoding(null)).size + val compressedSize = obj.checkpointSerialize(context.withEncoding(CordaSerializationEncoding.SNAPPY)).size // If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade. assertEquals(20222, uncompressedSize) assertEquals(1111, compressedSize) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt index 785ce47597..025e27a38a 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt @@ -5,7 +5,7 @@ import net.corda.core.DeleteForDJVM import net.corda.core.node.ServiceHub import net.corda.core.serialization.* import net.corda.core.serialization.internal.CheckpointSerializationContext -import net.corda.core.serialization.internal.CheckpointSerializationFactory +import net.corda.core.serialization.internal.CheckpointSerializer val serializationContextKey = SerializeAsTokenContext::class.java @@ -70,8 +70,8 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser */ @DeleteForDJVM class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext { - constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, { - serializationFactory.serialize(toBeTokenized, context.withTokenContext(this)) + constructor(toBeTokenized: Any, serializer: CheckpointSerializer, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, { + serializer.serialize(toBeTokenized, context.withTokenContext(this)) }) private val classNameToSingleton = mutableMapOf() diff --git a/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java b/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java index feab89ad92..0c9c9f2b5a 100644 --- a/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java +++ b/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java @@ -2,14 +2,14 @@ package net.corda.serialization.internal; import net.corda.core.serialization.*; import net.corda.core.serialization.internal.CheckpointSerializationContext; -import net.corda.core.serialization.internal.CheckpointSerializationFactory; +import net.corda.core.serialization.internal.CheckpointSerializer; import net.corda.node.serialization.kryo.CordaClosureSerializer; -import net.corda.testing.core.SerializationEnvironmentRule; import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import java.io.NotSerializableException; import java.io.Serializable; import java.util.Collections; import java.util.concurrent.Callable; @@ -23,12 +23,11 @@ public final class LambdaCheckpointSerializationTest { public final CheckpointSerializationEnvironmentRule testCheckpointSerialization = new CheckpointSerializationEnvironmentRule(); - private CheckpointSerializationFactory factory; private CheckpointSerializationContext context; + private CheckpointSerializer serializer; @Before public void setup() { - factory = testCheckpointSerialization.getCheckpointSerializationFactory(); context = new CheckpointSerializationContextImpl( getClass().getClassLoader(), AllWhitelist.INSTANCE, @@ -36,6 +35,8 @@ public final class LambdaCheckpointSerializationTest { true, null ); + + serializer = testCheckpointSerialization.getCheckpointSerializer(); } @Test @@ -63,11 +64,11 @@ public final class LambdaCheckpointSerializationTest { assertThat(throwable).hasMessage(CordaClosureSerializer.ERROR_MESSAGE); } - private SerializedBytes serialize(final T target) { - return factory.serialize(target, context); + private SerializedBytes serialize(final T target) throws NotSerializableException { + return serializer.serialize(target, context); } - private T deserialize(final SerializedBytes bytes, final Class type) { - return factory.deserialize(bytes, type, context); + private T deserialize(final SerializedBytes bytes, final Class type) throws NotSerializableException { + return serializer.deserialize(bytes, type, context); } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt index 73b799217d..e3426a1fd9 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt @@ -4,11 +4,9 @@ import net.corda.core.contracts.ContractAttachment import net.corda.core.identity.CordaX500Name import net.corda.core.serialization.* import net.corda.core.serialization.internal.CheckpointSerializationContext -import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointSerialize import net.corda.testing.contracts.DummyContract -import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import net.corda.testing.internal.rigorousMock import net.corda.testing.node.MockServices @@ -27,24 +25,25 @@ class ContractAttachmentSerializerTest { @JvmField val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() - private lateinit var factory: CheckpointSerializationFactory - private lateinit var context: CheckpointSerializationContext private lateinit var contextWithToken: CheckpointSerializationContext private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock()) @Before fun setup() { - factory = testCheckpointSerialization.checkpointSerializationFactory - context = testCheckpointSerialization.checkpointSerializationContext - contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices)) + contextWithToken = testCheckpointSerialization.checkpointSerializationContext.withTokenContext( + CheckpointSerializeAsTokenContextImpl( + Any(), + testCheckpointSerialization.checkpointSerializer, + testCheckpointSerialization.checkpointSerializationContext, + mockServices)) } @Test fun `write contract attachment and read it back`() { val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID) // no token context so will serialize the whole attachment - val serialized = contractAttachment.checkpointSerialize(factory, context) - val deserialized = serialized.checkpointDeserialize(factory, context) + val serialized = contractAttachment.checkpointSerialize() + val deserialized = serialized.checkpointDeserialize() assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.contract, deserialized.contract) @@ -59,8 +58,8 @@ class ContractAttachmentSerializerTest { mockServices.attachments.importAttachment(attachment.open(), "test", null) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) - val deserialized = serialized.checkpointDeserialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(contextWithToken) + val deserialized = serialized.checkpointDeserialize(contextWithToken) assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.contract, deserialized.contract) @@ -76,7 +75,7 @@ class ContractAttachmentSerializerTest { mockServices.attachments.importAttachment(attachment.open(), "test", null) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(contextWithToken) assertThat(serialized.size).isLessThan(largeAttachmentSize) } @@ -88,8 +87,8 @@ class ContractAttachmentSerializerTest { // don't importAttachment in mockService val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) - val deserialized = serialized.checkpointDeserialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(contextWithToken) + val deserialized = serialized.checkpointDeserialize(contextWithToken) assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java) } @@ -100,8 +99,8 @@ class ContractAttachmentSerializerTest { // don't importAttachment in mockService val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) - serialized.checkpointDeserialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(contextWithToken) + serialized.checkpointDeserialize(contextWithToken) // MissingAttachmentsException thrown if we try to open attachment } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt index 7f2bad6854..49e5d1f2ed 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt @@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.io.Output import net.corda.core.serialization.* import net.corda.core.serialization.internal.CheckpointSerializationContext -import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.OpaqueBytes @@ -14,7 +13,6 @@ import net.corda.node.serialization.kryo.CordaKryo import net.corda.node.serialization.kryo.DefaultKryoCustomizer import net.corda.node.serialization.kryo.kryoMagic import net.corda.testing.internal.rigorousMock -import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import org.assertj.core.api.Assertions.assertThat import org.junit.Before @@ -28,12 +26,10 @@ class SerializationTokenTest { @JvmField val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() - private lateinit var factory: CheckpointSerializationFactory private lateinit var context: CheckpointSerializationContext @Before fun setup() { - factory = testCheckpointSerialization.checkpointSerializationFactory context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java) } @@ -49,16 +45,16 @@ class SerializationTokenTest { override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size } - private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock()) + private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, testCheckpointSerialization.checkpointSerializer, context, rigorousMock()) @Test fun `write token and read tokenizable`() { val tokenizableBefore = LargeTokenizable() val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) + val serializedBytes = tokenizableBefore.checkpointSerialize(testContext) assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) - val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext) + val tokenizableAfter = serializedBytes.checkpointDeserialize(testContext) assertThat(tokenizableAfter).isSameAs(tokenizableBefore) } @@ -69,8 +65,8 @@ class SerializationTokenTest { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) - val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext) + val serializedBytes = tokenizableBefore.checkpointSerialize(testContext) + val tokenizableAfter = serializedBytes.checkpointDeserialize(testContext) assertThat(tokenizableAfter).isSameAs(tokenizableBefore) } @@ -79,7 +75,7 @@ class SerializationTokenTest { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) val testContext = this.context.withTokenContext(context) - tokenizableBefore.checkpointSerialize(factory, testContext) + tokenizableBefore.checkpointSerialize(testContext) } @Test(expected = UnsupportedOperationException::class) @@ -87,14 +83,14 @@ class SerializationTokenTest { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).checkpointSerialize(factory, testContext) - serializedBytes.checkpointDeserialize(factory, testContext) + val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).checkpointSerialize(testContext) + serializedBytes.checkpointDeserialize(testContext) } @Test(expected = KryoException::class) fun `no context set`() { val tokenizableBefore = UnitSerializeAsToken() - tokenizableBefore.checkpointSerialize(factory, context) + tokenizableBefore.checkpointSerialize(context) } @Test(expected = KryoException::class) @@ -112,7 +108,7 @@ class SerializationTokenTest { kryo.writeObject(it, emptyList()) } val serializedBytes = SerializedBytes(stream.toByteArray()) - serializedBytes.checkpointDeserialize(factory, testContext) + serializedBytes.checkpointDeserialize(testContext) } private class WrongTypeSerializeAsToken : SerializeAsToken { @@ -128,7 +124,7 @@ class SerializationTokenTest { val tokenizableBefore = WrongTypeSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) - serializedBytes.checkpointDeserialize(factory, testContext) + val serializedBytes = tokenizableBefore.checkpointSerialize(testContext) + serializedBytes.checkpointDeserialize(testContext) } } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt index 514b23a855..526f1251b6 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt @@ -3,9 +3,7 @@ package net.corda.testing.core import com.nhaarman.mockito_kotlin.any import com.nhaarman.mockito_kotlin.doAnswer import com.nhaarman.mockito_kotlin.whenever -import net.corda.core.DoNotImplement import net.corda.core.internal.staticField -import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.testing.common.internal.asContextEnv @@ -40,7 +38,7 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T /** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */ fun run(taskLabel: String, task: (SerializationEnvironment) -> T): T { - return SerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task) + return SerializationEnvironmentRule().apply { init() }.runTask(task) } } @@ -48,14 +46,14 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T val serializationFactory get() = env.serializationFactory override fun apply(base: Statement, description: Description): Statement { - init(description.toString()) + init() return object : Statement() { override fun evaluate() = runTask { base.evaluate() } } } - private fun init(envLabel: String) { - env = createTestSerializationEnv(envLabel) + private fun init() { + env = createTestSerializationEnv() } private fun runTask(task: (SerializationEnvironment) -> T): T { diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt index eb92d12cf6..69b634e107 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt @@ -39,7 +39,7 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = /** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */ fun run(taskLabel: String, task: (SerializationEnvironment) -> T): T { - return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task) + return CheckpointSerializationEnvironmentRule().apply { init() }.runTask(task) } } @@ -47,14 +47,14 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = private lateinit var env: SerializationEnvironment override fun apply(base: Statement, description: Description): Statement { - init(description.toString()) + init() return object : Statement() { override fun evaluate() = runTask { base.evaluate() } } } - private fun init(envLabel: String) { - env = createTestSerializationEnv(envLabel) + private fun init() { + env = createTestSerializationEnv() } private fun runTask(task: (SerializationEnvironment) -> T): T { @@ -65,7 +65,6 @@ class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = } } - val checkpointSerializationFactory get() = env.checkpointSerializationFactory val checkpointSerializationContext get() = env.checkpointContext - + val checkpointSerializer get() = env.checkpointSerializer } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt index 53bee6f798..3257cca802 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt @@ -4,11 +4,10 @@ import com.nhaarman.mockito_kotlin.doNothing import com.nhaarman.mockito_kotlin.whenever import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.core.DoNotImplement -import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.* import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT -import net.corda.node.serialization.kryo.KryoSerializationScheme +import net.corda.node.serialization.kryo.KryoCheckpointSerializer import net.corda.serialization.internal.* import net.corda.testing.core.SerializationEnvironmentRule import java.util.concurrent.ConcurrentHashMap @@ -30,22 +29,20 @@ fun withoutTestSerialization(callable: () -> T): T { // TODO: Delete this, s } } -internal fun createTestSerializationEnv(label: String): SerializationEnvironmentImpl { +internal fun createTestSerializationEnv(): SerializationEnvironment { val factory = SerializationFactoryImpl().apply { registerScheme(AMQPClientSerializationScheme(emptyList())) registerScheme(AMQPServerSerializationScheme(emptyList())) } - return object : SerializationEnvironmentImpl( + return SerializationEnvironment.with( factory, AMQP_P2P_CONTEXT, AMQP_RPC_SERVER_CONTEXT, AMQP_RPC_CLIENT_CONTEXT, AMQP_STORAGE_CONTEXT, KRYO_CHECKPOINT_CONTEXT, - CheckpointSerializationFactory(KryoSerializationScheme) - ) { - override fun toString() = "testSerializationEnv($label)" - } + KryoCheckpointSerializer + ) } /** @@ -54,7 +51,7 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment */ fun setGlobalSerialization(armed: Boolean): GlobalSerializationEnvironment { return if (armed) { - object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv("") { + object : GlobalSerializationEnvironment, SerializationEnvironment by createTestSerializationEnv() { override fun unset() { _globalSerializationEnv.set(null) inVMExecutors.remove(this) diff --git a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt index 570a9cbe2e..e65a7441be 100644 --- a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt +++ b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt @@ -8,11 +8,10 @@ import net.corda.cliutils.CordaCliWrapper import net.corda.cliutils.ExitCodes import net.corda.cliutils.start import net.corda.core.internal.isRegularFile -import net.corda.core.internal.rootMessage import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.deserialize -import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.utilities.base64ToByteArray import net.corda.core.utilities.hexToByteArray @@ -22,7 +21,6 @@ import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme import net.corda.serialization.internal.amqp.DeserializationInput import net.corda.serialization.internal.amqp.amqpMagic import org.slf4j.event.Level -import picocli.CommandLine import picocli.CommandLine.* import java.io.PrintStream import java.net.MalformedURLException @@ -128,7 +126,7 @@ class BlobInspector : CordaCliWrapper("blob-inspector", "Convert AMQP serialised private fun initialiseSerialization() { // Deserialise with the lenient carpenter as we only care for the AMQP field getters - _contextSerializationEnv.set(SerializationEnvironmentImpl( + _contextSerializationEnv.set(SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPInspectorSerializationScheme) }, diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/DemoBench.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/DemoBench.kt index 86430dadf3..dfecae05ac 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/DemoBench.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/DemoBench.kt @@ -2,7 +2,7 @@ package net.corda.demobench import javafx.scene.image.Image import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme -import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.demobench.views.DemoBenchView import net.corda.serialization.internal.AMQP_P2P_CONTEXT @@ -56,7 +56,7 @@ class DemoBench : App(DemoBenchView::class) { } private fun initialiseSerialization() { - nodeSerializationEnv = SerializationEnvironmentImpl( + nodeSerializationEnv = SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPClientSerializationScheme(emptyList())) }, diff --git a/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt b/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt index ce978e4131..d0efa8d492 100644 --- a/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt +++ b/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt @@ -1,9 +1,10 @@ package net.corda.bootstrapper.serialization -import net.corda.core.serialization.internal.SerializationEnvironmentImpl +import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT +import net.corda.node.serialization.kryo.KryoCheckpointSerializer import net.corda.serialization.internal.AMQP_P2P_CONTEXT import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT import net.corda.serialization.internal.SerializationFactoryImpl @@ -14,14 +15,16 @@ class SerializationEngine { synchronized(this) { if (nodeSerializationEnv == null) { val classloader = this::class.java.classLoader - nodeSerializationEnv = SerializationEnvironmentImpl( + nodeSerializationEnv = SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme(emptyList())) }, p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), - checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) + + checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader), + checkpointSerializer = KryoCheckpointSerializer ) } }