From 98c92ef16fd049fd2b18f4be6e3de8c42e0217e9 Mon Sep 17 00:00:00 2001 From: Dominic Fox <40790090+distributedleetravis@users.noreply.github.com> Date: Wed, 19 Sep 2018 14:23:29 +0100 Subject: [PATCH] CORDA-1391: Separate out Checkpoint serialization (#3922) * Separate out Checkpoint serialization * Update kdocs * Rename checkpoint serialization extension methods * Fix bungled rename * Limit API changes * Simplify CheckpointSerializationFactory * Add CheckpointSerializationScheme to API checker * CheckpointSerializationScheme should not be implemented * Move checkpoint serialisation to internal package * Remove CheckpointSerializationScheme from api-current * Quarantine internal classes * Remove checkpoint context from public API * Remove checkpoint context from public API * Fix test failures * Completely decouple SerializationTestHelpers and CheckpointSerializationTestHelpers * Remove CHECKPOINT use case * Remove stray reference to checkpoint use case * Fix broken test --- .ci/api-current.txt | 4 - core-deterministic/build.gradle | 1 + .../CheckpointSerializationFactory.kt | 74 +++++++ .../core/serialization/SerializationAPI.kt | 9 +- .../internal/CheckpointSerializationAPI.kt | 198 ++++++++++++++++++ .../internal/SerializationEnvironment.kt | 10 +- .../flows/SerializationApiInJavaTest.java | 8 +- .../corda/core/utilities/KotlinUtilsTest.kt | 16 +- .../corda/node/internal/CheckpointVerifier.kt | 11 +- .../kotlin/net/corda/node/internal/Node.kt | 5 +- .../serialization/kryo/CordaClassResolver.kt | 4 +- .../net/corda/node/serialization/kryo/Kryo.kt | 12 +- .../kryo/KryoSerializationScheme.kt | 67 +++--- .../kryo/KryoServerSerializationScheme.kt | 14 -- .../statemachine/ActionExecutorImpl.kt | 10 +- .../statemachine/FlowStateMachineImpl.kt | 8 +- .../SingleThreadedStateMachineManager.kt | 22 +- ...FiberDeserializationCheckingInterceptor.kt | 10 +- .../transactions/RaftTransactionCommitLog.kt | 14 +- .../node/serialization/kryo/KryoTests.kt | 97 ++++----- .../persistence/DBCheckpointStorageTests.kt | 8 +- .../internal/CheckpointSerializationScheme.kt | 49 +++++ .../internal/SerializeAsTokenContextImpl.kt | 55 ++++- .../internal/UseCaseAwareness.kt | 8 + .../internal/amqp/AMQPSerializationScheme.kt | 2 - .../amqp/custom/PrivateKeySerializer.kt | 5 +- .../ForbiddenLambdaSerializationTests.java | 5 +- .../LambdaCheckpointSerializationTest.java | 23 +- .../ContractAttachmentSerializerTest.kt | 38 ++-- .../internal/CordaClassResolverTests.kt | 7 +- .../internal/PrivateKeySerializationTest.kt | 6 +- .../internal/SerializationTokenTest.kt | 41 ++-- .../testing/core/SerializationTestHelpers.kt | 2 +- .../CheckpointSerializationTestHelpers.kt | 71 +++++++ .../InternalSerializationTestHelpers.kt | 8 +- .../net/corda/blobinspector/BlobInspector.kt | 5 +- .../serialization/SerializationHelper.kt | 3 +- 37 files changed, 677 insertions(+), 253 deletions(-) create mode 100644 core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt create mode 100644 core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt delete mode 100644 node/src/main/kotlin/net/corda/node/serialization/kryo/KryoServerSerializationScheme.kt create mode 100644 serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt create mode 100644 testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt diff --git a/.ci/api-current.txt b/.ci/api-current.txt index cafe828a66..13ae84905c 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -4445,8 +4445,6 @@ public interface net.corda.core.serialization.SerializationCustomSerializer public abstract PROXY toProxy(OBJ) ## public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object - @NotNull - public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT() @NotNull public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT() @NotNull @@ -6883,8 +6881,6 @@ public final class net.corda.testing.core.SerializationEnvironmentRule extends j @NotNull public org.junit.runners.model.Statement apply(org.junit.runners.model.Statement, org.junit.runner.Description) @NotNull - public final net.corda.core.serialization.SerializationContext getCheckpointContext() - @NotNull public final net.corda.core.serialization.SerializationFactory getSerializationFactory() public static final net.corda.testing.core.SerializationEnvironmentRule$Companion Companion ## diff --git a/core-deterministic/build.gradle b/core-deterministic/build.gradle index 839384fdd7..f7cdb65e9c 100644 --- a/core-deterministic/build.gradle +++ b/core-deterministic/build.gradle @@ -50,6 +50,7 @@ task patchCore(type: Zip, dependsOn: coreJarTask) { from(zipTree(originalJar)) { exclude 'net/corda/core/internal/*ToggleField*.class' exclude 'net/corda/core/serialization/*SerializationFactory*.class' + exclude 'net/corda/core/serialization/internal/CheckpointSerializationFactory*.class' } reproducibleFileOrder = true diff --git a/core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt b/core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt new file mode 100644 index 0000000000..dbb6fb54c0 --- /dev/null +++ b/core-deterministic/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationFactory.kt @@ -0,0 +1,74 @@ +package net.corda.core.serialization.internal + +import net.corda.core.KeepForDJVM +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.ByteSequence +import java.io.NotSerializableException + +/** + * A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization + * context. + */ +@KeepForDJVM +class CheckpointSerializationFactory( + private val scheme: CheckpointSerializationScheme +) { + + val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext + + private val creator: List = Exception().stackTrace.asList() + + /** + * Deserialize the bytes in to an object, using the prefixed bytes to determine the format. + * + * @param byteSequence The bytes to deserialize, including a format header prefix. + * @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown. + * @param context A context that configures various parameters to deserialization. + */ + @Throws(NotSerializableException::class) + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T { + return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) } + } + + /** + * Serialize an object to bytes using the preferred serialization format version from the context. + * + * @param obj The object to be serialized. + * @param context A context that configures various parameters to serialization, including the serialization format version. + */ + fun serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes { + return withCurrentContext(context) { scheme.serialize(obj, context) } + } + + override fun toString(): String { + return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}" + } + + override fun equals(other: Any?): Boolean { + return other is CheckpointSerializationFactory && other.scheme == this.scheme + } + + override fun hashCode(): Int = scheme.hashCode() + + private var _currentContext: CheckpointSerializationContext? = null + + /** + * Change the current context inside the block to that supplied. + */ + fun withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T { + val priorContext = _currentContext + if (context != null) _currentContext = context + try { + return block() + } finally { + if (context != null) _currentContext = priorContext + } + } + + companion object { + /** + * A default factory for serialization/deserialization. + */ + val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index c5df7f7069..3a0ee16ce0 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -207,7 +207,13 @@ interface SerializationContext { * The use case that we are serializing for, since it influences the implementations chosen. */ @KeepForDJVM - enum class UseCase { P2P, RPCServer, RPCClient, Storage, Checkpoint, Testing } + enum class UseCase { + P2P, + RPCServer, + RPCClient, + Storage, + Testing + } } /** @@ -230,7 +236,6 @@ object SerializationDefaults { @DeleteForDJVM val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext @DeleteForDJVM val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext @DeleteForDJVM val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext - @DeleteForDJVM val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext } /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt new file mode 100644 index 0000000000..448d1ab25f --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/CheckpointSerializationAPI.kt @@ -0,0 +1,198 @@ +package net.corda.core.serialization.internal + +import net.corda.core.DeleteForDJVM +import net.corda.core.DoNotImplement +import net.corda.core.KeepForDJVM +import net.corda.core.crypto.SecureHash +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.sequence +import java.io.NotSerializableException + + +object CheckpointSerializationDefaults { + @DeleteForDJVM + val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext + val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory +} + +/** + * A class for serializing and deserializing objects at checkpoints, using Kryo serialization. + */ +@KeepForDJVM +class CheckpointSerializationFactory( + private val scheme: CheckpointSerializationScheme +) { + + val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext + + private val creator: List = Exception().stackTrace.asList() + + /** + * Deserialize the bytes in to an object, using the prefixed bytes to determine the format. + * + * @param byteSequence The bytes to deserialize, including a format header prefix. + * @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown. + * @param context A context that configures various parameters to deserialization. + */ + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T { + return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) } + } + + /** + * Serialize an object to bytes using the preferred serialization format version from the context. + * + * @param obj The object to be serialized. + * @param context A context that configures various parameters to serialization, including the serialization format version. + */ + fun serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes { + return withCurrentContext(context) { scheme.serialize(obj, context) } + } + + override fun toString(): String { + return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}" + } + + override fun equals(other: Any?): Boolean { + return other is CheckpointSerializationFactory && other.scheme == this.scheme + } + + override fun hashCode(): Int = scheme.hashCode() + + private val _currentContext = ThreadLocal() + + /** + * Change the current context inside the block to that supplied. + */ + fun withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T { + val priorContext = _currentContext.get() + if (context != null) _currentContext.set(context) + try { + return block() + } finally { + if (context != null) _currentContext.set(priorContext) + } + } + + companion object { + val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory + } +} + +@KeepForDJVM +@DoNotImplement +interface CheckpointSerializationScheme { + @Throws(NotSerializableException::class) + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T + + @Throws(NotSerializableException::class) + fun serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes +} + +/** + * Parameters to checkpoint serialization and deserialization. + */ +@KeepForDJVM +@DoNotImplement +interface CheckpointSerializationContext { + /** + * If non-null, apply this encoding (typically compression) when serializing. + */ + val encoding: SerializationEncoding? + /** + * The class loader to use for deserialization. + */ + val deserializationClassLoader: ClassLoader + /** + * A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized. + */ + val whitelist: ClassWhitelist + /** + * A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing. + */ + val encodingWhitelist: EncodingWhitelist + /** + * A map of any addition properties specific to the particular use case. + */ + val properties: Map + /** + * Duplicate references to the same object preserved in the wire format and when deserialized when this is true, + * otherwise they appear as new copies of the object. + */ + val objectReferencesEnabled: Boolean + + /** + * Helper method to return a new context based on this context with the property added. + */ + fun withProperty(property: Any, value: Any): CheckpointSerializationContext + + /** + * Helper method to return a new context based on this context with object references disabled. + */ + fun withoutReferences(): CheckpointSerializationContext + + /** + * Helper method to return a new context based on this context with the deserialization class loader changed. + */ + fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext + + /** + * Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers. + * (Requires the attachment storage to have been enabled). + */ + @Throws(MissingAttachmentsException::class) + fun withAttachmentsClassLoader(attachmentHashes: List): CheckpointSerializationContext + + /** + * Helper method to return a new context based on this context with the given class specifically whitelisted. + */ + fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext + + /** + * A shallow copy of this context but with the given (possibly null) encoding. + */ + fun withEncoding(encoding: SerializationEncoding?): CheckpointSerializationContext + + /** + * A shallow copy of this context but with the given encoding whitelist. + */ + fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext +} + +/* + * The following extension methods are disambiguated from the AMQP-serialization methods by requiring that an + * explicit [CheckpointSerializationContext] parameter be provided. + */ + +/* + * Convenience extension method for deserializing a ByteSequence, utilising the default factory. + */ +inline fun ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, + context: CheckpointSerializationContext): T { + return serializationFactory.deserialize(this, T::class.java, context) +} + +/** + * Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory. + */ +inline fun SerializedBytes.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, + context: CheckpointSerializationContext): T { + return serializationFactory.deserialize(this, T::class.java, context) +} + +/** + * Convenience extension method for deserializing a ByteArray, utilising the default factory. + */ +inline fun ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, + context: CheckpointSerializationContext): T { + require(isNotEmpty()) { "Empty bytes" } + return this.sequence().checkpointDeserialize(serializationFactory, context) +} + +/** + * Convenience extension method for serializing an object of type T, utilising the default factory. + */ +fun T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory, + context: CheckpointSerializationContext): SerializedBytes { + return serializationFactory.serialize(this, context) +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt index 28c6ad7900..441cd52be4 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/SerializationEnvironment.kt @@ -12,11 +12,12 @@ import net.corda.core.serialization.SerializationFactory @KeepForDJVM interface SerializationEnvironment { val serializationFactory: SerializationFactory + val checkpointSerializationFactory: CheckpointSerializationFactory val p2pContext: SerializationContext val rpcServerContext: SerializationContext val rpcClientContext: SerializationContext val storageContext: SerializationContext - val checkpointContext: SerializationContext + val checkpointContext: CheckpointSerializationContext } @KeepForDJVM @@ -26,18 +27,21 @@ open class SerializationEnvironmentImpl( rpcServerContext: SerializationContext? = null, rpcClientContext: SerializationContext? = null, storageContext: SerializationContext? = null, - checkpointContext: SerializationContext? = null) : SerializationEnvironment { + checkpointContext: CheckpointSerializationContext? = null, + checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment { // Those that are passed in as null are never inited: override lateinit var rpcServerContext: SerializationContext override lateinit var rpcClientContext: SerializationContext override lateinit var storageContext: SerializationContext - override lateinit var checkpointContext: SerializationContext + override lateinit var checkpointContext: CheckpointSerializationContext + override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory init { rpcServerContext?.let { this.rpcServerContext = it } rpcClientContext?.let { this.rpcClientContext = it } storageContext?.let { this.storageContext = it } checkpointContext?.let { this.checkpointContext = it } + checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it } } } diff --git a/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java b/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java index 297adeff8f..55e66c1766 100644 --- a/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java +++ b/core/src/test/java/net/corda/core/flows/SerializationApiInJavaTest.java @@ -1,11 +1,14 @@ package net.corda.core.flows; +import net.corda.core.serialization.internal.CheckpointSerializationDefaults; +import net.corda.core.serialization.internal.CheckpointSerializationFactory; import net.corda.core.serialization.SerializationDefaults; import net.corda.core.serialization.SerializationFactory; import net.corda.testing.core.SerializationEnvironmentRule; import org.junit.Rule; import org.junit.Test; +import static net.corda.core.serialization.internal.CheckpointSerializationAPIKt.checkpointSerialize; import static net.corda.core.serialization.SerializationAPIKt.serialize; import static org.junit.Assert.assertNull; @@ -28,10 +31,13 @@ public class SerializationApiInJavaTest { public void enforceSerializationDefaultsApi() { SerializationDefaults defaults = SerializationDefaults.INSTANCE; SerializationFactory factory = defaults.getSERIALIZATION_FACTORY(); + + CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE; + CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY(); serialize("hello", factory, defaults.getP2P_CONTEXT()); serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT()); serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT()); serialize("hello", factory, defaults.getSTORAGE_CONTEXT()); - serialize("hello", factory, defaults.getCHECKPOINT_CONTEXT()); + checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT()); } } diff --git a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt index fefb890213..debb6307f0 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt @@ -3,9 +3,10 @@ package net.corda.core.utilities import com.esotericsoftware.kryo.KryoException import net.corda.core.crypto.random63BitValue import net.corda.core.serialization.* +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT -import net.corda.node.serialization.kryo.kryoMagic -import net.corda.serialization.internal.SerializationContextImpl +import net.corda.serialization.internal.CheckpointSerializationContextImpl import net.corda.testing.core.SerializationEnvironmentRule import org.assertj.core.api.Assertions.assertThat import org.junit.Rule @@ -24,12 +25,11 @@ class KotlinUtilsTest { @Rule val expectedEx: ExpectedException = ExpectedException.none() - private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = SerializationContextImpl(kryoMagic, + private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = CheckpointSerializationContextImpl( javaClass.classLoader, EmptyWhitelist, emptyMap(), true, - SerializationContext.UseCase.Checkpoint, null) @Test @@ -44,7 +44,7 @@ class KotlinUtilsTest { fun `checkpointing a transient property with non-capturing lambda`() { val original = NonCapturingTransientProperty() val originalVal = original.transientVal - val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT) + val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT) val copyVal = copy.transientVal assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copy.transientVal).isEqualTo(copyVal) @@ -55,14 +55,14 @@ class KotlinUtilsTest { expectedEx.expect(KryoException::class.java) expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") val original = NonCapturingTransientProperty() - original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT) + original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT) } @Test fun `checkpointing a transient property with capturing lambda`() { val original = CapturingTransientProperty("Hello") val originalVal = original.transientVal - val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT) + val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT) val copyVal = copy.transientVal assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copy.transientVal).isEqualTo(copyVal) @@ -76,7 +76,7 @@ class KotlinUtilsTest { val original = CapturingTransientProperty("Hello") - original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT) + original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT) } private class NullTransientProperty { diff --git a/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt b/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt index ad2a7c903d..53f22b3147 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CheckpointVerifier.kt @@ -5,11 +5,12 @@ import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.node.ServiceHub import net.corda.core.serialization.SerializationDefaults -import net.corda.core.serialization.deserialize +import net.corda.core.serialization.internal.CheckpointSerializationDefaults +import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.statemachine.SubFlow import net.corda.node.services.statemachine.SubFlowVersion -import net.corda.serialization.internal.SerializeAsTokenContextImpl +import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl import net.corda.serialization.internal.withTokenContext object CheckpointVerifier { @@ -19,13 +20,13 @@ object CheckpointVerifier { * @throws CheckpointIncompatibleException if any offending checkpoint is found. */ fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List) { - val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( - SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( + CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) ) checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) -> val checkpoint = try { - serializedCheckpoint.deserialize(context = checkpointSerializationContext) + serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext) } catch (e: Exception) { throw CheckpointIncompatibleException.CannotBeDeserialisedException(e) } diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 0216158b2f..3d204ac7de 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -21,6 +21,7 @@ import net.corda.core.messaging.RPCOps import net.corda.core.node.NetworkParameters import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub +import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort @@ -37,7 +38,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT -import net.corda.node.serialization.kryo.KryoServerSerializationScheme +import net.corda.node.serialization.kryo.KryoSerializationScheme import net.corda.node.services.Permissions import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.ServiceHubInternal @@ -449,8 +450,8 @@ open class Node(configuration: NodeConfiguration, SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps)) - registerScheme(KryoServerSerializationScheme()) }, + checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), diff --git a/node/src/main/kotlin/net/corda/node/serialization/kryo/CordaClassResolver.kt b/node/src/main/kotlin/net/corda/node/serialization/kryo/CordaClassResolver.kt index e9ce3e4d56..e3ff2584f7 100644 --- a/node/src/main/kotlin/net/corda/node/serialization/kryo/CordaClassResolver.kt +++ b/node/src/main/kotlin/net/corda/node/serialization/kryo/CordaClassResolver.kt @@ -8,8 +8,8 @@ import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.Util import net.corda.core.internal.kotlinObjectInstance import net.corda.core.internal.writer +import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.ClassWhitelist -import net.corda.core.serialization.SerializationContext import net.corda.core.utilities.contextLogger import net.corda.serialization.internal.AttachmentsClassLoader import net.corda.serialization.internal.MutableClassWhitelist @@ -25,7 +25,7 @@ import java.util.* /** * Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo */ -class CordaClassResolver(serializationContext: SerializationContext) : DefaultClassResolver() { +class CordaClassResolver(serializationContext: CheckpointSerializationContext) : DefaultClassResolver() { val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist) // These classes are assignment-compatible Java equivalents of Kotlin classes. diff --git a/node/src/main/kotlin/net/corda/node/serialization/kryo/Kryo.kt b/node/src/main/kotlin/net/corda/node/serialization/kryo/Kryo.kt index e7b94b9635..674f4702f2 100644 --- a/node/src/main/kotlin/net/corda/node/serialization/kryo/Kryo.kt +++ b/node/src/main/kotlin/net/corda/node/serialization/kryo/Kryo.kt @@ -14,12 +14,11 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.TransactionSignature import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint -import net.corda.core.serialization.SerializationContext.UseCase.Storage import net.corda.core.serialization.SerializeAsTokenContext import net.corda.core.serialization.SerializedBytes import net.corda.core.transactions.* import net.corda.core.utilities.OpaqueBytes +import net.corda.serialization.internal.checkUseCase import net.corda.serialization.internal.serializationContextKey import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -275,16 +274,9 @@ object SignedTransactionSerializer : Serializer() { } } -sealed class UseCaseSerializer(private val allowedUseCases: EnumSet) : Serializer() { - protected fun checkUseCase() { - net.corda.serialization.internal.checkUseCase(allowedUseCases) - } -} - @ThreadSafe -object PrivateKeySerializer : UseCaseSerializer(EnumSet.of(Storage, Checkpoint)) { +object PrivateKeySerializer : Serializer() { override fun write(kryo: Kryo, output: Output, obj: PrivateKey) { - checkUseCase() output.writeBytesWithLength(obj.encoded) } diff --git a/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt b/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt index 417fceabec..22edf8258e 100644 --- a/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt +++ b/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoSerializationScheme.kt @@ -10,10 +10,9 @@ import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.serializers.ClosureSerializer import net.corda.core.internal.uncheckedCast -import net.corda.core.serialization.ClassWhitelist -import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationDefaults -import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.CheckpointSerializationScheme import net.corda.core.utilities.ByteSequence import net.corda.serialization.internal.* import java.security.PublicKey @@ -32,46 +31,30 @@ private object AutoCloseableSerialisationDetector : Serializer() override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") } -abstract class AbstractKryoSerializationScheme : SerializationScheme { +object KryoSerializationScheme : CheckpointSerializationScheme { private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() - protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool - protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool - - // this can be overridden in derived serialization schemes - protected open val publicKeySerializer: Serializer = PublicKeySerializer - - private fun getPool(context: SerializationContext): KryoPool { + private fun getPool(context: CheckpointSerializationContext): KryoPool { return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { - when (context.useCase) { - SerializationContext.UseCase.Checkpoint -> - KryoPool.Builder { - val serializer = Fiber.getFiberSerializer(false) as KryoSerializer - val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) } - // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that - val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } - serializer.kryo.apply { - field.set(this, classResolver) - // don't allow overriding the public key serializer for checkpointing - DefaultKryoCustomizer.customize(this) - addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) - register(ClosureSerializer.Closure::class.java, CordaClosureSerializer) - classLoader = it.second - } - }.build() - SerializationContext.UseCase.RPCClient -> - rpcClientKryoPool(context) - SerializationContext.UseCase.RPCServer -> - rpcServerKryoPool(context) - else -> - KryoPool.Builder { - DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context)), publicKeySerializer).apply { classLoader = it.second } - }.build() - } + KryoPool.Builder { + val serializer = Fiber.getFiberSerializer(false) as KryoSerializer + val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) } + // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that + val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } + serializer.kryo.apply { + field.set(this, classResolver) + // don't allow overriding the public key serializer for checkpointing + DefaultKryoCustomizer.customize(this) + addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) + register(ClosureSerializer.Closure::class.java, CordaClosureSerializer) + classLoader = it.second + } + }.build() + } } - private fun SerializationContext.kryo(task: Kryo.() -> T): T { + private fun CheckpointSerializationContext.kryo(task: Kryo.() -> T): T { return getPool(this).run { kryo -> kryo.context.ensureCapacity(properties.size) properties.forEach { kryo.context.put(it.key, it.value) } @@ -83,7 +66,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { } } - override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: CheckpointSerializationContext): T { val dataBytes = kryoMagic.consume(byteSequence) ?: throw KryoException("Serialized bytes header does not match expected format.") return context.kryo { @@ -111,7 +94,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { } } - override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + override fun serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes { return context.kryo { SerializedBytes(kryoOutput { kryoMagic.writeTo(this) @@ -131,13 +114,11 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme { } } -val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl( - kryoMagic, +val KRYO_CHECKPOINT_CONTEXT = CheckpointSerializationContextImpl( SerializationDefaults.javaClass.classLoader, QuasarWhitelist, emptyMap(), true, - SerializationContext.UseCase.Checkpoint, null, AlwaysAcceptEncodingWhitelist ) diff --git a/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoServerSerializationScheme.kt b/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoServerSerializationScheme.kt deleted file mode 100644 index 86a4226812..0000000000 --- a/node/src/main/kotlin/net/corda/node/serialization/kryo/KryoServerSerializationScheme.kt +++ /dev/null @@ -1,14 +0,0 @@ -package net.corda.node.serialization.kryo - -import com.esotericsoftware.kryo.pool.KryoPool -import net.corda.core.serialization.SerializationContext -import net.corda.serialization.internal.CordaSerializationMagic - -class KryoServerSerializationScheme : AbstractKryoSerializationScheme() { - override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean { - return magic == kryoMagic && target == SerializationContext.UseCase.Checkpoint - } - - override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() - override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() -} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index ec73a4c3bb..00a0406dbe 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -4,9 +4,9 @@ import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable import com.codahale.metrics.* import net.corda.core.internal.concurrent.thenMatch -import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.serialize +import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.contextLogger import net.corda.core.utilities.trace import net.corda.node.services.api.CheckpointStorage @@ -27,7 +27,7 @@ class ActionExecutorImpl( private val checkpointStorage: CheckpointStorage, private val flowMessaging: FlowMessaging, private val stateMachineManager: StateMachineManagerInternal, - private val checkpointSerializationContext: SerializationContext, + private val checkpointSerializationContext: CheckpointSerializationContext, metrics: MetricRegistry ) : ActionExecutor { @@ -237,7 +237,7 @@ class ActionExecutorImpl( } private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes { - return checkpoint.serialize(context = checkpointSerializationContext) + return checkpoint.checkpointSerialize(context = checkpointSerializationContext) } private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 4cbd80cfd2..3b099e5c9a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -12,8 +12,8 @@ import net.corda.core.cordapp.Cordapp import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.internal.* -import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.serialize +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.Try import net.corda.core.utilities.debug import net.corda.core.utilities.trace @@ -69,7 +69,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val actionExecutor: ActionExecutor, val stateMachine: StateMachine, val serviceHub: ServiceHubInternal, - val checkpointSerializationContext: SerializationContext, + val checkpointSerializationContext: CheckpointSerializationContext, val unfinishedFibers: ReusableLatch ) @@ -369,7 +369,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, Event.Suspend( ioRequest = ioRequest, maySkipCheckpoint = skipPersistingCheckpoint, - fiber = this.serialize(context = serializationContext.value) + fiber = this.checkpointSerialize(context = serializationContext.value) ) } catch (throwable: Throwable) { Event.Error(throwable) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index e1dbc171fd..a47799e0aa 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.map import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.DataFeed import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.CheckpointSerializationDefaults +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger @@ -36,7 +40,7 @@ import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.injectOldProgressTracker import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction -import net.corda.serialization.internal.SerializeAsTokenContextImpl +import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl import net.corda.serialization.internal.withTokenContext import org.apache.activemq.artemis.utils.ReusableLatch import rx.Observable @@ -103,7 +107,7 @@ class SingleThreadedStateMachineManager( private val transitionExecutor = makeTransitionExecutor() private val ourSenderUUID = serviceHub.networkService.ourSenderUUID - private var checkpointSerializationContext: SerializationContext? = null + private var checkpointSerializationContext: CheckpointSerializationContext? = null private var actionExecutor: ActionExecutor? = null override val allStateMachines: List> @@ -122,8 +126,8 @@ class SingleThreadedStateMachineManager( override fun start(tokenizableServices: List) { checkQuasarJavaAgentPresence() - val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( - SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( + CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) ) this.checkpointSerializationContext = checkpointSerializationContext this.actionExecutor = makeActionExecutor(checkpointSerializationContext) @@ -531,7 +535,7 @@ class SingleThreadedStateMachineManager( val resultFuture = openFuture() flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) flowLogic.stateMachine = flowStateMachineImpl - val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) + val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!) val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) @@ -613,7 +617,7 @@ class SingleThreadedStateMachineManager( private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes): Checkpoint? { return try { - serializedCheckpoint.deserialize(context = checkpointSerializationContext!!) + serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext!!) } catch (exception: Throwable) { logger.error("Encountered unrestorable checkpoint!", exception) null @@ -658,7 +662,7 @@ class SingleThreadedStateMachineManager( val resultFuture = openFuture() val fiber = when (flowState) { is FlowState.Unstarted -> { - val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) + val logic = flowState.frozenFlowLogic.checkpointDeserialize(context = checkpointSerializationContext!!) val state = StateMachineState( checkpoint = checkpoint, pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), @@ -677,7 +681,7 @@ class SingleThreadedStateMachineManager( fiber } is FlowState.Started -> { - val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) + val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext!!) val state = StateMachineState( checkpoint = checkpoint, pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), @@ -742,7 +746,7 @@ class SingleThreadedStateMachineManager( } } - private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor { + private fun makeActionExecutor(checkpointSerializationContext: CheckpointSerializationContext): ActionExecutor { return ActionExecutorImpl( serviceHub, checkpointStorage, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt index 67b1733a90..37033977de 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/FiberDeserializationCheckingInterceptor.kt @@ -2,9 +2,9 @@ package net.corda.node.services.statemachine.interceptors import co.paralleluniverse.fibers.Suspendable import net.corda.core.flows.StateMachineRunId -import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize +import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.checkpointDeserialize import net.corda.core.utilities.contextLogger import net.corda.node.services.statemachine.ActionExecutor import net.corda.node.services.statemachine.Event @@ -68,7 +68,7 @@ class FiberDeserializationChecker { private val jobQueue = LinkedBlockingQueue() private var foundUnrestorableFibers: Boolean = false - fun start(checkpointSerializationContext: SerializationContext) { + fun start(checkpointSerializationContext: CheckpointSerializationContext) { require(checkerThread == null) checkerThread = thread(name = "FiberDeserializationChecker") { while (true) { @@ -76,7 +76,7 @@ class FiberDeserializationChecker { when (job) { is Job.Check -> { try { - job.serializedFiber.deserialize(context = checkpointSerializationContext) + job.serializedFiber.checkpointDeserialize(context = checkpointSerializationContext) } catch (throwable: Throwable) { log.error("Encountered unrestorable checkpoint!", throwable) foundUnrestorableFibers = true diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt b/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt index 59a0734ceb..72fa52bdc1 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/RaftTransactionCommitLog.kt @@ -20,10 +20,10 @@ import net.corda.core.flows.StateConsumptionDetails import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.notary.isConsumedByTheSameTx import net.corda.core.internal.notary.validateTimeWindow -import net.corda.core.serialization.SerializationDefaults -import net.corda.core.serialization.SerializationFactory -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize +import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationDefaults +import net.corda.core.serialization.internal.CheckpointSerializationFactory +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug @@ -200,11 +200,11 @@ class RaftTransactionCommitLog( } class CordaKryoSerializer : TypeSerializer { - private val context = SerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY) - private val factory = SerializationFactory.defaultFactory + private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY) + private val factory = CheckpointSerializationFactory.defaultFactory override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) { - val serialized = obj.serialize(context = context) + val serialized = obj.checkpointSerialize(context = context) buffer.writeInt(serialized.size) buffer.write(serialized.bytes) } diff --git a/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt b/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt index 79df364b5c..5598f38f67 100644 --- a/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt +++ b/node/src/test/kotlin/net/corda/node/serialization/kryo/KryoTests.kt @@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.primitives.Ints import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.whenever @@ -13,6 +12,10 @@ import net.corda.core.contracts.PrivacySalt import net.corda.core.crypto.* import net.corda.core.internal.FetchDataFlow import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.CheckpointSerializationFactory +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.sequence @@ -36,16 +39,6 @@ import java.util.* import kotlin.collections.ArrayList import kotlin.test.* -class TestScheme : AbstractKryoSerializationScheme() { - override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean { - return magic == kryoMagic && target != SerializationContext.UseCase.RPCClient - } - - override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() - - override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException() -} - @RunWith(Parameterized::class) class KryoTests(private val compression: CordaSerializationEncoding?) { companion object { @@ -55,18 +48,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun compression() = arrayOf(null) + CordaSerializationEncoding.values() } - private lateinit var factory: SerializationFactory - private lateinit var context: SerializationContext + private lateinit var factory: CheckpointSerializationFactory + private lateinit var context: CheckpointSerializationContext @Before fun setup() { - factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) } - context = SerializationContextImpl(kryoMagic, + factory = CheckpointSerializationFactory(KryoSerializationScheme) + context = CheckpointSerializationContextImpl( javaClass.classLoader, AllWhitelist, emptyMap(), true, - SerializationContext.UseCase.Storage, compression, rigorousMock().also { if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression) @@ -77,15 +69,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `simple data class`() { val birthday = Instant.parse("1984-04-17T00:30:00.00Z") val mike = Person("mike", birthday) - val bits = mike.serialize(factory, context) - assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday)) + val bits = mike.checkpointSerialize(factory, context) + assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday)) } @Test fun `null values`() { val bob = Person("bob", null) - val bits = bob.serialize(factory, context) - assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null)) + val bits = bob.checkpointSerialize(factory, context) + assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null)) } @Test @@ -93,10 +85,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { val noReferencesContext = context.withoutReferences() val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence() val originalList : ArrayList = ArrayList().apply { this += obj } - val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext) + val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext) originalList += obj deserialisedList += obj - assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext)) + assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext)) } @Test @@ -113,14 +105,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { this += instant this += instant } - assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext)) + assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext)) } @Test fun `cyclic object graph`() { val cyclic = Cyclic(3) - val bits = cyclic.serialize(factory, context) - assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic) + val bits = cyclic.checkpointSerialize(factory, context) + assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic) } @Test @@ -132,7 +124,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { signature.verify(bitsToSign) assertThatThrownBy { signature.verify(wrongBits) } - val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context) + val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) deserialisedSignature.verify(bitsToSign) assertThatThrownBy { deserialisedSignature.verify(wrongBits) } @@ -140,28 +132,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `write and read Kotlin object singleton`() { - val serialised = TestSingleton.serialize(factory, context) - val deserialised = serialised.deserialize(factory, context) + val serialised = TestSingleton.checkpointSerialize(factory, context) + val deserialised = serialised.checkpointDeserialize(factory, context) assertThat(deserialised).isSameAs(TestSingleton) } @Test fun `check Kotlin EmptyList can be serialised`() { - val deserialisedList: List = emptyList().serialize(factory, context).deserialize(factory, context) + val deserialisedList: List = emptyList().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) assertEquals(0, deserialisedList.size) assertEquals(Collections.emptyList().javaClass, deserialisedList.javaClass) } @Test fun `check Kotlin EmptySet can be serialised`() { - val deserialisedSet: Set = emptySet().serialize(factory, context).deserialize(factory, context) + val deserialisedSet: Set = emptySet().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) assertEquals(0, deserialisedSet.size) assertEquals(Collections.emptySet().javaClass, deserialisedSet.javaClass) } @Test fun `check Kotlin EmptyMap can be serialised`() { - val deserialisedMap: Map = emptyMap().serialize(factory, context).deserialize(factory, context) + val deserialisedMap: Map = emptyMap().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) assertEquals(0, deserialisedMap.size) assertEquals(Collections.emptyMap().javaClass, deserialisedMap.javaClass) } @@ -169,7 +161,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `InputStream serialisation`() { val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() } - val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context) + val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context) for (i in 0..12344) { assertEquals(rubbish[i], readRubbishStream.read().toByte()) } @@ -179,7 +171,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `InputStream serialisation does not write trailing garbage`() { val byteArrays = listOf("123", "456").map { it.toByteArray() } - val streams = byteArrays.map { it.inputStream() }.serialize(factory, context).deserialize(factory, context).iterator() + val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator() byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) } assertFalse(streams.hasNext()) } @@ -190,16 +182,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { val testBytes = testString.toByteArray() val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)) - val serializedMetaData = meta.serialize(factory, context).bytes - val meta2 = serializedMetaData.deserialize(factory, context) + val serializedMetaData = meta.checkpointSerialize(factory, context).bytes + val meta2 = serializedMetaData.checkpointDeserialize(factory, context) assertEquals(meta2, meta) } @Test fun `serialize - deserialize Logger`() { - val storageContext: SerializationContext = context // TODO: make it storage context + val storageContext: CheckpointSerializationContext = context val logger = LoggerFactory.getLogger("aName") - val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext) + val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext) assertEquals(logger.name, logger2.name) assertTrue(logger === logger2) } @@ -211,7 +203,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { SecureHash.sha256(rubbish), rubbish.size, rubbish.inputStream() - ).serialize(factory, context).deserialize(factory, context) + ).checkpointSerialize(factory, context).checkpointDeserialize(factory, context) for (i in 0..12344) { assertEquals(rubbish[i], readRubbishStream.read().toByte()) } @@ -238,8 +230,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32 )) - val serializedBytes = expected.serialize(factory, context) - val actual = serializedBytes.deserialize(factory, context) + val serializedBytes = expected.checkpointSerialize(factory, context) + val actual = serializedBytes.checkpointDeserialize(factory, context) assertEquals(expected, actual) } @@ -286,15 +278,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { } } Tmp() - val factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) } - val context = SerializationContextImpl(kryoMagic, + val factory = CheckpointSerializationFactory(KryoSerializationScheme) + val context = CheckpointSerializationContextImpl( javaClass.classLoader, AllWhitelist, emptyMap(), true, - SerializationContext.UseCase.P2P, null) - pt.serialize(factory, context) + pt.checkpointSerialize(factory, context) } @Test @@ -302,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { val exception = IllegalArgumentException("fooBar") val toBeSuppressedOnSenderSide = IllegalStateException("bazz1") exception.addSuppressed(toBeSuppressedOnSenderSide) - val exception2 = exception.serialize(factory, context).deserialize(factory, context) + val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) assertEquals(exception.message, exception2.message) assertEquals(1, exception2.suppressed.size) @@ -317,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { @Test fun `serialize - deserialize Exception no suppressed`() { val exception = IllegalArgumentException("fooBar") - val exception2 = exception.serialize(factory, context).deserialize(factory, context) + val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) assertEquals(exception.message, exception2.message) assertEquals(0, exception2.suppressed.size) @@ -331,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `serialize - deserialize HashNotFound`() { val randomHash = SecureHash.randomSHA256() val exception = FetchDataFlow.HashNotFound(randomHash) - val exception2 = exception.serialize(factory, context).deserialize(factory, context) + val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context) assertEquals(randomHash, exception2.requested) } @@ -339,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { fun `compression has the desired effect`() { compression ?: return val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } - val compressed = data.serialize(factory, context) + val compressed = data.checkpointSerialize(factory, context) assertEquals(.5, compressed.size.toDouble() / data.size, .03) - assertArrayEquals(data, compressed.deserialize(factory, context)) + assertArrayEquals(data, compressed.checkpointDeserialize(factory, context)) } @Test fun `a particular encoding can be banned for deserialization`() { compression ?: return doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression) - val compressed = "whatever".serialize(factory, context) - catchThrowable { compressed.deserialize(factory, context) }.run { + val compressed = "whatever".checkpointSerialize(factory, context) + catchThrowable { compressed.checkpointDeserialize(factory, context) }.run { assertSame(KryoException::class.java, javaClass) assertEquals(encodingNotPermittedFormat.format(compression), message) } @@ -360,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) { class Holder(val holder: ByteArray) val obj = Holder(ByteArray(20000)) - val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size - val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size + val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size + val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size // If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade. assertEquals(20222, uncompressedSize) assertEquals(1111, compressedSize) diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index bc02aa19f0..7fd4072d38 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -3,9 +3,9 @@ package net.corda.node.services.persistence import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId -import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.internal.CheckpointSerializationDefaults import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.serialize +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.node.internal.CheckpointIncompatibleException import net.corda.node.internal.CheckpointVerifier import net.corda.node.internal.configureDatabase @@ -189,9 +189,9 @@ class DBCheckpointStorageTests { val logic: FlowLogic<*> = object : FlowLogic() { override fun call() {} } - val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) + val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version)).getOrThrow() - return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) + return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT) } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt new file mode 100644 index 0000000000..c370084e7a --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/CheckpointSerializationScheme.kt @@ -0,0 +1,49 @@ +package net.corda.serialization.internal + +import net.corda.core.KeepForDJVM +import net.corda.core.crypto.SecureHash +import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext + +@KeepForDJVM +data class CheckpointSerializationContextImpl @JvmOverloads constructor( + override val deserializationClassLoader: ClassLoader, + override val whitelist: ClassWhitelist, + override val properties: Map, + override val objectReferencesEnabled: Boolean, + override val encoding: SerializationEncoding?, + override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : CheckpointSerializationContext { + private val builder = AttachmentsClassLoaderBuilder(properties, deserializationClassLoader) + + /** + * {@inheritDoc} + * + * We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context. + */ + override fun withAttachmentsClassLoader(attachmentHashes: List): CheckpointSerializationContext { + properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean == true || return this + val classLoader = builder.build(attachmentHashes) ?: return this + return withClassLoader(classLoader) + } + + override fun withProperty(property: Any, value: Any): CheckpointSerializationContext { + return copy(properties = properties + (property to value)) + } + + override fun withoutReferences(): CheckpointSerializationContext { + return copy(objectReferencesEnabled = false) + } + + override fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext { + return copy(deserializationClassLoader = classLoader) + } + + override fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext { + return copy(whitelist = object : ClassWhitelist { + override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name + }) + } + + override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding) + override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist) +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt index d028e91168..785ce47597 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializeAsTokenContextImpl.kt @@ -3,14 +3,14 @@ package net.corda.serialization.internal import net.corda.core.DeleteForDJVM import net.corda.core.node.ServiceHub -import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationFactory -import net.corda.core.serialization.SerializeAsToken -import net.corda.core.serialization.SerializeAsTokenContext +import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.CheckpointSerializationFactory val serializationContextKey = SerializeAsTokenContext::class.java fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext) +fun CheckpointSerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): CheckpointSerializationContext = this.withProperty(serializationContextKey, serializationContext) /** * A context for mapping SerializationTokens to/from SerializeAsTokens. @@ -55,6 +55,53 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser } } + override fun getSingleton(className: String) = classNameToSingleton[className] + ?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this") +} + +/** + * A context for mapping SerializationTokens to/from SerializeAsTokens. + * + * A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens. + * In our case this can be the [ServiceHub]. + * + * Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary + * when serializing to enable/disable tokenization. + */ +@DeleteForDJVM +class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext { + constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, { + serializationFactory.serialize(toBeTokenized, context.withTokenContext(this)) + }) + + private val classNameToSingleton = mutableMapOf() + private var readOnly = false + + init { + /** + * Go ahead and eagerly serialize the object to register all of the tokens in the context. + * + * This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which + * are encountered in the object graph as they are serialized and will therefore register the token to + * object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or + * accidental registrations from occuring as these could not be deserialized in a deserialization-first + * scenario if they are not part of this iniital context construction serialization. + */ + init(this) + readOnly = true + } + + override fun putSingleton(toBeTokenized: SerializeAsToken) { + val className = toBeTokenized.javaClass.name + if (className !in classNameToSingleton) { + // Only allowable if we are in SerializeAsTokenContext init (readOnly == false) + if (readOnly) { + throw UnsupportedOperationException("Attempt to write token for lazy registered $className. All tokens should be registered during context construction.") + } + classNameToSingleton[className] = toBeTokenized + } + } + override fun getSingleton(className: String) = classNameToSingleton[className] ?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this") } \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/UseCaseAwareness.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/UseCaseAwareness.kt index ebe030b81d..2ce03e1e3b 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/UseCaseAwareness.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/UseCaseAwareness.kt @@ -13,3 +13,11 @@ fun checkUseCase(allowedUseCases: EnumSet) { throw IllegalStateException("UseCase '${currentContext.useCase}' is not within '$allowedUseCases'") } } + +fun checkUseCase(allowedUseCase: SerializationContext.UseCase) { + val currentContext: SerializationContext = SerializationFactory.currentFactory?.currentContext + ?: throw IllegalStateException("Current context is not set") + if (allowedUseCase != currentContext.useCase) { + throw IllegalStateException("UseCase '${currentContext.useCase}' is not '$allowedUseCase'") + } +} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializationScheme.kt index f35c808077..94c0d2223f 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializationScheme.kt @@ -163,8 +163,6 @@ abstract class AbstractAMQPSerializationScheme( return synchronized(serializerFactoriesForContexts) { serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { when (context.useCase) { - SerializationContext.UseCase.Checkpoint -> - throw IllegalStateException("AMQP should not be used for checkpoint serialization.") SerializationContext.UseCase.RPCClient -> rpcClientSerializerFactory(context) SerializationContext.UseCase.RPCServer -> diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt index e87c679570..7bf9bbf344 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt @@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp.custom import net.corda.core.crypto.Crypto import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint import net.corda.core.serialization.SerializationContext.UseCase.Storage import net.corda.serialization.internal.amqp.* import net.corda.serialization.internal.checkUseCase @@ -13,14 +12,12 @@ import java.util.* object PrivateKeySerializer : CustomSerializer.Implements(PrivateKey::class.java) { - private val allowedUseCases = EnumSet.of(Storage, Checkpoint) - override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList()))) override fun writeDescribedObject(obj: PrivateKey, data: Data, type: Type, output: SerializationOutput, context: SerializationContext ) { - checkUseCase(allowedUseCases) + checkUseCase(Storage) output.writeObject(obj.encoded, data, clazz, context) } diff --git a/serialization/src/test/java/net/corda/serialization/internal/ForbiddenLambdaSerializationTests.java b/serialization/src/test/java/net/corda/serialization/internal/ForbiddenLambdaSerializationTests.java index 130870d544..03b0117c07 100644 --- a/serialization/src/test/java/net/corda/serialization/internal/ForbiddenLambdaSerializationTests.java +++ b/serialization/src/test/java/net/corda/serialization/internal/ForbiddenLambdaSerializationTests.java @@ -4,7 +4,6 @@ import com.google.common.collect.Maps; import net.corda.core.serialization.SerializationContext; import net.corda.core.serialization.SerializationFactory; import net.corda.core.serialization.SerializedBytes; -import net.corda.serialization.internal.amqp.AMQPNotSerializableException; import net.corda.serialization.internal.amqp.SchemaKt; import net.corda.testing.core.SerializationEnvironmentRule; import org.junit.Before; @@ -20,8 +19,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.ThrowableAssert.catchThrowable; public final class ForbiddenLambdaSerializationTests { + private EnumSet contexts = EnumSet.complementOf( - EnumSet.of(SerializationContext.UseCase.Checkpoint, SerializationContext.UseCase.Testing)); + EnumSet.of(SerializationContext.UseCase.Testing)); + @Rule public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule(); private SerializationFactory factory; diff --git a/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java b/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java index 1cae8762bb..feab89ad92 100644 --- a/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java +++ b/serialization/src/test/java/net/corda/serialization/internal/LambdaCheckpointSerializationTest.java @@ -1,11 +1,11 @@ package net.corda.serialization.internal; -import net.corda.core.serialization.SerializationContext; -import net.corda.core.serialization.SerializationFactory; -import net.corda.core.serialization.SerializedBytes; +import net.corda.core.serialization.*; +import net.corda.core.serialization.internal.CheckpointSerializationContext; +import net.corda.core.serialization.internal.CheckpointSerializationFactory; import net.corda.node.serialization.kryo.CordaClosureSerializer; -import net.corda.node.serialization.kryo.KryoSerializationSchemeKt; import net.corda.testing.core.SerializationEnvironmentRule; +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -18,21 +18,22 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.ThrowableAssert.catchThrowable; public final class LambdaCheckpointSerializationTest { + @Rule - public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule(); - private SerializationFactory factory; - private SerializationContext context; + public final CheckpointSerializationEnvironmentRule testCheckpointSerialization = + new CheckpointSerializationEnvironmentRule(); + + private CheckpointSerializationFactory factory; + private CheckpointSerializationContext context; @Before public void setup() { - factory = testSerialization.getSerializationFactory(); - context = new SerializationContextImpl( - KryoSerializationSchemeKt.getKryoMagic(), + factory = testCheckpointSerialization.getCheckpointSerializationFactory(); + context = new CheckpointSerializationContextImpl( getClass().getClassLoader(), AllWhitelist.INSTANCE, Collections.emptyMap(), true, - SerializationContext.UseCase.Checkpoint, null ); } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt index a4d17cc52e..73b799217d 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/ContractAttachmentSerializerTest.kt @@ -3,8 +3,13 @@ package net.corda.serialization.internal import net.corda.core.contracts.ContractAttachment import net.corda.core.identity.CordaX500Name import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.CheckpointSerializationFactory +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.testing.contracts.DummyContract import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import net.corda.testing.internal.rigorousMock import net.corda.testing.node.MockServices import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY @@ -17,28 +22,29 @@ import org.junit.Test import kotlin.test.assertEquals class ContractAttachmentSerializerTest { + @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() + val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() - private lateinit var factory: SerializationFactory - private lateinit var context: SerializationContext - private lateinit var contextWithToken: SerializationContext + private lateinit var factory: CheckpointSerializationFactory + private lateinit var context: CheckpointSerializationContext + private lateinit var contextWithToken: CheckpointSerializationContext private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock()) @Before fun setup() { - factory = testSerialization.serializationFactory - context = testSerialization.checkpointContext - contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices)) + factory = testCheckpointSerialization.checkpointSerializationFactory + context = testCheckpointSerialization.checkpointSerializationContext + contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices)) } @Test fun `write contract attachment and read it back`() { val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID) // no token context so will serialize the whole attachment - val serialized = contractAttachment.serialize(factory, context) - val deserialized = serialized.deserialize(factory, context) + val serialized = contractAttachment.checkpointSerialize(factory, context) + val deserialized = serialized.checkpointDeserialize(factory, context) assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.contract, deserialized.contract) @@ -53,8 +59,8 @@ class ContractAttachmentSerializerTest { mockServices.attachments.importAttachment(attachment.open(), "test", null) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.serialize(factory, contextWithToken) - val deserialized = serialized.deserialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) + val deserialized = serialized.checkpointDeserialize(factory, contextWithToken) assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.contract, deserialized.contract) @@ -70,7 +76,7 @@ class ContractAttachmentSerializerTest { mockServices.attachments.importAttachment(attachment.open(), "test", null) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.serialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) assertThat(serialized.size).isLessThan(largeAttachmentSize) } @@ -82,8 +88,8 @@ class ContractAttachmentSerializerTest { // don't importAttachment in mockService val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.serialize(factory, contextWithToken) - val deserialized = serialized.deserialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) + val deserialized = serialized.checkpointDeserialize(factory, contextWithToken) assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java) } @@ -94,8 +100,8 @@ class ContractAttachmentSerializerTest { // don't importAttachment in mockService val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) - val serialized = contractAttachment.serialize(factory, contextWithToken) - serialized.deserialize(factory, contextWithToken) + val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken) + serialized.checkpointDeserialize(factory, contextWithToken) // MissingAttachmentsException thrown if we try to open attachment } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/CordaClassResolverTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/CordaClassResolverTests.kt index f78752577a..860a04a81c 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/CordaClassResolverTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/CordaClassResolverTests.kt @@ -11,12 +11,11 @@ import com.nhaarman.mockito_kotlin.verify import com.nhaarman.mockito_kotlin.whenever import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER import net.corda.core.node.services.AttachmentStorage +import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.SerializationContext import net.corda.node.serialization.kryo.CordaClassResolver import net.corda.node.serialization.kryo.CordaKryo -import net.corda.node.serialization.kryo.kryoMagic import net.corda.testing.internal.rigorousMock import net.corda.testing.services.MockAttachmentStorage import org.junit.Rule @@ -115,8 +114,8 @@ class CordaClassResolverTests { val emptyMapClass = mapOf().javaClass } - private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null) - private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null) + private val emptyWhitelistContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, null) + private val allButBlacklistedContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, null) @Test fun `Annotation on enum works for specialised entries`() { CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/PrivateKeySerializationTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/PrivateKeySerializationTest.kt index 3b1d46f342..bdd5b672ef 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/PrivateKeySerializationTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/PrivateKeySerializationTest.kt @@ -3,6 +3,8 @@ package net.corda.serialization.internal import net.corda.core.crypto.Crypto import net.corda.core.serialization.SerializationContext.UseCase.* import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.internal.CheckpointSerializationDefaults +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.serialize import net.corda.testing.core.SerializationEnvironmentRule import org.assertj.core.api.Assertions.assertThatThrownBy @@ -33,13 +35,13 @@ class PrivateKeySerializationTest(private val privateKey: PrivateKey, private va @Test fun `passed with expected UseCases`() { assertTrue { privateKey.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes.isNotEmpty() } - assertTrue { privateKey.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() } + assertTrue { privateKey.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() } } @Test fun `failed with wrong UseCase`() { assertThatThrownBy { privateKey.serialize(context = SerializationDefaults.P2P_CONTEXT) } .isInstanceOf(IllegalStateException::class.java) - .hasMessageContaining("UseCase '$P2P' is not within") + .hasMessageContaining("UseCase '$P2P' is not 'Storage") } } \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt index b18e1d725b..7f2bad6854 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/SerializationTokenTest.kt @@ -4,6 +4,10 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.io.Output import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CheckpointSerializationContext +import net.corda.core.serialization.internal.CheckpointSerializationFactory +import net.corda.core.serialization.internal.checkpointDeserialize +import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.utilities.OpaqueBytes import net.corda.node.serialization.kryo.CordaClassResolver import net.corda.node.serialization.kryo.CordaKryo @@ -11,6 +15,7 @@ import net.corda.node.serialization.kryo.DefaultKryoCustomizer import net.corda.node.serialization.kryo.kryoMagic import net.corda.testing.internal.rigorousMock import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Rule @@ -18,16 +23,18 @@ import org.junit.Test import java.io.ByteArrayOutputStream class SerializationTokenTest { + @Rule @JvmField - val testSerialization = SerializationEnvironmentRule() - private lateinit var factory: SerializationFactory - private lateinit var context: SerializationContext + val testCheckpointSerialization = CheckpointSerializationEnvironmentRule() + + private lateinit var factory: CheckpointSerializationFactory + private lateinit var context: CheckpointSerializationContext @Before fun setup() { - factory = testSerialization.serializationFactory - context = testSerialization.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java) + factory = testCheckpointSerialization.checkpointSerializationFactory + context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java) } // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized @@ -42,16 +49,16 @@ class SerializationTokenTest { override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size } - private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock()) + private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock()) @Test fun `write token and read tokenizable`() { val tokenizableBefore = LargeTokenizable() val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.serialize(factory, testContext) + val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) - val tokenizableAfter = serializedBytes.deserialize(factory, testContext) + val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext) assertThat(tokenizableAfter).isSameAs(tokenizableBefore) } @@ -62,8 +69,8 @@ class SerializationTokenTest { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.serialize(factory, testContext) - val tokenizableAfter = serializedBytes.deserialize(factory, testContext) + val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) + val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext) assertThat(tokenizableAfter).isSameAs(tokenizableBefore) } @@ -72,7 +79,7 @@ class SerializationTokenTest { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) val testContext = this.context.withTokenContext(context) - tokenizableBefore.serialize(factory, testContext) + tokenizableBefore.checkpointSerialize(factory, testContext) } @Test(expected = UnsupportedOperationException::class) @@ -80,14 +87,14 @@ class SerializationTokenTest { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).serialize(factory, testContext) - serializedBytes.deserialize(factory, testContext) + val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).checkpointSerialize(factory, testContext) + serializedBytes.checkpointDeserialize(factory, testContext) } @Test(expected = KryoException::class) fun `no context set`() { val tokenizableBefore = UnitSerializeAsToken() - tokenizableBefore.serialize(factory, context) + tokenizableBefore.checkpointSerialize(factory, context) } @Test(expected = KryoException::class) @@ -105,7 +112,7 @@ class SerializationTokenTest { kryo.writeObject(it, emptyList()) } val serializedBytes = SerializedBytes(stream.toByteArray()) - serializedBytes.deserialize(factory, testContext) + serializedBytes.checkpointDeserialize(factory, testContext) } private class WrongTypeSerializeAsToken : SerializeAsToken { @@ -121,7 +128,7 @@ class SerializationTokenTest { val tokenizableBefore = WrongTypeSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) val testContext = this.context.withTokenContext(context) - val serializedBytes = tokenizableBefore.serialize(factory, testContext) - serializedBytes.deserialize(factory, testContext) + val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext) + serializedBytes.checkpointDeserialize(factory, testContext) } } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt index d896710d26..514b23a855 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/core/SerializationTestHelpers.kt @@ -5,6 +5,7 @@ import com.nhaarman.mockito_kotlin.doAnswer import com.nhaarman.mockito_kotlin.whenever import net.corda.core.DoNotImplement import net.corda.core.internal.staticField +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.testing.common.internal.asContextEnv @@ -45,7 +46,6 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T private lateinit var env: SerializationEnvironment val serializationFactory get() = env.serializationFactory - val checkpointContext get() = env.checkpointContext override fun apply(base: Statement, description: Description): Statement { init(description.toString()) diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt new file mode 100644 index 0000000000..eb92d12cf6 --- /dev/null +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/core/internal/CheckpointSerializationTestHelpers.kt @@ -0,0 +1,71 @@ +package net.corda.testing.core.internal + +import com.nhaarman.mockito_kotlin.any +import com.nhaarman.mockito_kotlin.doAnswer +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.internal.staticField +import net.corda.core.serialization.internal.SerializationEnvironment +import net.corda.core.serialization.internal.effectiveSerializationEnv +import net.corda.testing.common.internal.asContextEnv +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.internal.createTestSerializationEnv +import net.corda.testing.internal.inVMExecutors +import net.corda.testing.internal.rigorousMock +import net.corda.testing.internal.testThreadFactory +import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector +import org.junit.rules.TestRule +import org.junit.runner.Description +import org.junit.runners.model.Statement +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors + +/** + * A test checkpoint serialization rule implementation for use in tests. + * + * @param inheritable whether new threads inherit the environment, use sparingly. + */ +class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule { + companion object { + init { + // Can't turn it off, and it creates threads that do serialization, so hack it: + InVMConnector::class.staticField("threadPoolExecutor").value = rigorousMock().also { + doAnswer { + inVMExecutors.computeIfAbsent(effectiveSerializationEnv) { + Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally. + }.execute(it.arguments[0] as Runnable) + }.whenever(it).execute(any()) + } + } + + /** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */ + fun run(taskLabel: String, task: (SerializationEnvironment) -> T): T { + return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task) + } + } + + + private lateinit var env: SerializationEnvironment + + override fun apply(base: Statement, description: Description): Statement { + init(description.toString()) + return object : Statement() { + override fun evaluate() = runTask { base.evaluate() } + } + } + + private fun init(envLabel: String) { + env = createTestSerializationEnv(envLabel) + } + + private fun runTask(task: (SerializationEnvironment) -> T): T { + try { + return env.asContextEnv(inheritable, task) + } finally { + inVMExecutors.remove(env) + } + } + + val checkpointSerializationFactory get() = env.checkpointSerializationFactory + val checkpointSerializationContext get() = env.checkpointContext + +} diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt index 5a8e4c83a1..53bee6f798 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalSerializationTestHelpers.kt @@ -4,10 +4,11 @@ import com.nhaarman.mockito_kotlin.doNothing import com.nhaarman.mockito_kotlin.whenever import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.core.DoNotImplement +import net.corda.core.serialization.internal.CheckpointSerializationFactory import net.corda.core.serialization.internal.* import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT -import net.corda.node.serialization.kryo.KryoServerSerializationScheme +import net.corda.node.serialization.kryo.KryoSerializationScheme import net.corda.serialization.internal.* import net.corda.testing.core.SerializationEnvironmentRule import java.util.concurrent.ConcurrentHashMap @@ -33,8 +34,6 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment val factory = SerializationFactoryImpl().apply { registerScheme(AMQPClientSerializationScheme(emptyList())) registerScheme(AMQPServerSerializationScheme(emptyList())) - // needed for checkpointing - registerScheme(KryoServerSerializationScheme()) } return object : SerializationEnvironmentImpl( factory, @@ -42,7 +41,8 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment AMQP_RPC_SERVER_CONTEXT, AMQP_RPC_CLIENT_CONTEXT, AMQP_STORAGE_CONTEXT, - KRYO_CHECKPOINT_CONTEXT + KRYO_CHECKPOINT_CONTEXT, + CheckpointSerializationFactory(KryoSerializationScheme) ) { override fun toString() = "testSerializationEnv($label)" } diff --git a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt index 84e054876e..570a9cbe2e 100644 --- a/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt +++ b/tools/blobinspector/src/main/kotlin/net/corda/blobinspector/BlobInspector.kt @@ -17,10 +17,7 @@ import net.corda.core.serialization.internal._contextSerializationEnv import net.corda.core.utilities.base64ToByteArray import net.corda.core.utilities.hexToByteArray import net.corda.core.utilities.sequence -import net.corda.serialization.internal.AMQP_P2P_CONTEXT -import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT -import net.corda.serialization.internal.CordaSerializationMagic -import net.corda.serialization.internal.SerializationFactoryImpl +import net.corda.serialization.internal.* import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme import net.corda.serialization.internal.amqp.DeserializationInput import net.corda.serialization.internal.amqp.amqpMagic diff --git a/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt b/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt index 8e350c8c14..ce978e4131 100644 --- a/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt +++ b/tools/network-bootstrapper/src/main/kotlin/net/corda/bootstrapper/serialization/SerializationHelper.kt @@ -3,6 +3,7 @@ package net.corda.bootstrapper.serialization import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.node.serialization.amqp.AMQPServerSerializationScheme +import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.serialization.internal.AMQP_P2P_CONTEXT import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT import net.corda.serialization.internal.SerializationFactoryImpl @@ -20,7 +21,7 @@ class SerializationEngine { p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), - checkpointContext = AMQP_P2P_CONTEXT.withClassLoader(classloader) + checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader) ) } }