CORDA-1391: Separate out Checkpoint serialization (#3922)

* Separate out Checkpoint serialization

* Update kdocs

* Rename checkpoint serialization extension methods

* Fix bungled rename

* Limit API changes

* Simplify CheckpointSerializationFactory

* Add CheckpointSerializationScheme to API checker

* CheckpointSerializationScheme should not be implemented

* Move checkpoint serialisation to internal package

* Remove CheckpointSerializationScheme from api-current

* Quarantine internal classes

* Remove checkpoint context from public API

* Remove checkpoint context from public API

* Fix test failures

* Completely decouple SerializationTestHelpers and CheckpointSerializationTestHelpers

* Remove CHECKPOINT use case

* Remove stray reference to checkpoint use case

* Fix broken test
This commit is contained in:
Dominic Fox 2018-09-19 14:23:29 +01:00 committed by GitHub
parent d10892c09e
commit 98c92ef16f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 677 additions and 253 deletions

View File

@ -4445,8 +4445,6 @@ public interface net.corda.core.serialization.SerializationCustomSerializer
public abstract PROXY toProxy(OBJ)
##
public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object
@NotNull
public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT()
@NotNull
public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT()
@NotNull
@ -6883,8 +6881,6 @@ public final class net.corda.testing.core.SerializationEnvironmentRule extends j
@NotNull
public org.junit.runners.model.Statement apply(org.junit.runners.model.Statement, org.junit.runner.Description)
@NotNull
public final net.corda.core.serialization.SerializationContext getCheckpointContext()
@NotNull
public final net.corda.core.serialization.SerializationFactory getSerializationFactory()
public static final net.corda.testing.core.SerializationEnvironmentRule$Companion Companion
##

View File

@ -50,6 +50,7 @@ task patchCore(type: Zip, dependsOn: coreJarTask) {
from(zipTree(originalJar)) {
exclude 'net/corda/core/internal/*ToggleField*.class'
exclude 'net/corda/core/serialization/*SerializationFactory*.class'
exclude 'net/corda/core/serialization/internal/CheckpointSerializationFactory*.class'
}
reproducibleFileOrder = true

View File

@ -0,0 +1,74 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/**
* A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization
* context.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private var _currentContext: CheckpointSerializationContext? = null
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext
if (context != null) _currentContext = context
try {
return block()
} finally {
if (context != null) _currentContext = priorContext
}
}
companion object {
/**
* A default factory for serialization/deserialization.
*/
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}

View File

@ -207,7 +207,13 @@ interface SerializationContext {
* The use case that we are serializing for, since it influences the implementations chosen.
*/
@KeepForDJVM
enum class UseCase { P2P, RPCServer, RPCClient, Storage, Checkpoint, Testing }
enum class UseCase {
P2P,
RPCServer,
RPCClient,
Storage,
Testing
}
}
/**
@ -230,7 +236,6 @@ object SerializationDefaults {
@DeleteForDJVM val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext
@DeleteForDJVM val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext
@DeleteForDJVM val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext
@DeleteForDJVM val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
}
/**

View File

@ -0,0 +1,198 @@
package net.corda.core.serialization.internal
import net.corda.core.DeleteForDJVM
import net.corda.core.DoNotImplement
import net.corda.core.KeepForDJVM
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.sequence
import java.io.NotSerializableException
object CheckpointSerializationDefaults {
@DeleteForDJVM
val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory
}
/**
* A class for serializing and deserializing objects at checkpoints, using Kryo serialization.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private val _currentContext = ThreadLocal<CheckpointSerializationContext?>()
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}
companion object {
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationScheme {
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T>
}
/**
* Parameters to checkpoint serialization and deserialization.
*/
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationContext {
/**
* If non-null, apply this encoding (typically compression) when serializing.
*/
val encoding: SerializationEncoding?
/**
* The class loader to use for deserialization.
*/
val deserializationClassLoader: ClassLoader
/**
* A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized.
*/
val whitelist: ClassWhitelist
/**
* A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing.
*/
val encodingWhitelist: EncodingWhitelist
/**
* A map of any addition properties specific to the particular use case.
*/
val properties: Map<Any, Any>
/**
* Duplicate references to the same object preserved in the wire format and when deserialized when this is true,
* otherwise they appear as new copies of the object.
*/
val objectReferencesEnabled: Boolean
/**
* Helper method to return a new context based on this context with the property added.
*/
fun withProperty(property: Any, value: Any): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with object references disabled.
*/
fun withoutReferences(): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the deserialization class loader changed.
*/
fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers.
* (Requires the attachment storage to have been enabled).
*/
@Throws(MissingAttachmentsException::class)
fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the given class specifically whitelisted.
*/
fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given (possibly null) encoding.
*/
fun withEncoding(encoding: SerializationEncoding?): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given encoding whitelist.
*/
fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext
}
/*
* The following extension methods are disambiguated from the AMQP-serialization methods by requiring that an
* explicit [CheckpointSerializationContext] parameter be provided.
*/
/*
* Convenience extension method for deserializing a ByteSequence, utilising the default factory.
*/
inline fun <reified T : Any> ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory.
*/
inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing a ByteArray, utilising the default factory.
*/
inline fun <reified T : Any> ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
require(isNotEmpty()) { "Empty bytes" }
return this.sequence().checkpointDeserialize(serializationFactory, context)
}
/**
* Convenience extension method for serializing an object of type T, utilising the default factory.
*/
fun <T : Any> T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): SerializedBytes<T> {
return serializationFactory.serialize(this, context)
}

View File

@ -12,11 +12,12 @@ import net.corda.core.serialization.SerializationFactory
@KeepForDJVM
interface SerializationEnvironment {
val serializationFactory: SerializationFactory
val checkpointSerializationFactory: CheckpointSerializationFactory
val p2pContext: SerializationContext
val rpcServerContext: SerializationContext
val rpcClientContext: SerializationContext
val storageContext: SerializationContext
val checkpointContext: SerializationContext
val checkpointContext: CheckpointSerializationContext
}
@KeepForDJVM
@ -26,18 +27,21 @@ open class SerializationEnvironmentImpl(
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: SerializationContext? = null) : SerializationEnvironment {
checkpointContext: CheckpointSerializationContext? = null,
checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: SerializationContext
override lateinit var checkpointContext: CheckpointSerializationContext
override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory
init {
rpcServerContext?.let { this.rpcServerContext = it }
rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it }
checkpointContext?.let { this.checkpointContext = it }
checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it }
}
}

View File

@ -1,11 +1,14 @@
package net.corda.core.flows;
import net.corda.core.serialization.internal.CheckpointSerializationDefaults;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory;
import net.corda.testing.core.SerializationEnvironmentRule;
import org.junit.Rule;
import org.junit.Test;
import static net.corda.core.serialization.internal.CheckpointSerializationAPIKt.checkpointSerialize;
import static net.corda.core.serialization.SerializationAPIKt.serialize;
import static org.junit.Assert.assertNull;
@ -28,10 +31,13 @@ public class SerializationApiInJavaTest {
public void enforceSerializationDefaultsApi() {
SerializationDefaults defaults = SerializationDefaults.INSTANCE;
SerializationFactory factory = defaults.getSERIALIZATION_FACTORY();
CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE;
CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY();
serialize("hello", factory, defaults.getP2P_CONTEXT());
serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT());
serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT());
serialize("hello", factory, defaults.getSTORAGE_CONTEXT());
serialize("hello", factory, defaults.getCHECKPOINT_CONTEXT());
checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT());
}
}

View File

@ -3,9 +3,10 @@ package net.corda.core.utilities
import com.esotericsoftware.kryo.KryoException
import net.corda.core.crypto.random63BitValue
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule
@ -24,12 +25,11 @@ class KotlinUtilsTest {
@Rule
val expectedEx: ExpectedException = ExpectedException.none()
private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = SerializationContextImpl(kryoMagic,
private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = CheckpointSerializationContextImpl(
javaClass.classLoader,
EmptyWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
null)
@Test
@ -44,7 +44,7 @@ class KotlinUtilsTest {
fun `checkpointing a transient property with non-capturing lambda`() {
val original = NonCapturingTransientProperty()
val originalVal = original.transientVal
val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
@ -55,14 +55,14 @@ class KotlinUtilsTest {
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = NonCapturingTransientProperty()
original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
}
@Test
fun `checkpointing a transient property with capturing lambda`() {
val original = CapturingTransientProperty("Hello")
val originalVal = original.transientVal
val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
@ -76,7 +76,7 @@ class KotlinUtilsTest {
val original = CapturingTransientProperty("Hello")
original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
}
private class NullTransientProperty {

View File

@ -5,11 +5,12 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.SubFlow
import net.corda.node.services.statemachine.SubFlowVersion
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
object CheckpointVerifier {
@ -19,13 +20,13 @@ object CheckpointVerifier {
* @throws CheckpointIncompatibleException if any offending checkpoint is found.
*/
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->
val checkpoint = try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext)
} catch (e: Exception) {
throw CheckpointIncompatibleException.CannotBeDeserialisedException(e)
}

View File

@ -21,6 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
@ -37,7 +38,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
@ -449,8 +450,8 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme())
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),

View File

@ -8,8 +8,8 @@ import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util
import net.corda.core.internal.kotlinObjectInstance
import net.corda.core.internal.writer
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.utilities.contextLogger
import net.corda.serialization.internal.AttachmentsClassLoader
import net.corda.serialization.internal.MutableClassWhitelist
@ -25,7 +25,7 @@ import java.util.*
/**
* Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo
*/
class CordaClassResolver(serializationContext: SerializationContext) : DefaultClassResolver() {
class CordaClassResolver(serializationContext: CheckpointSerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
// These classes are assignment-compatible Java equivalents of Kotlin classes.

View File

@ -14,12 +14,11 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.*
import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.checkUseCase
import net.corda.serialization.internal.serializationContextKey
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@ -275,16 +274,9 @@ object SignedTransactionSerializer : Serializer<SignedTransaction>() {
}
}
sealed class UseCaseSerializer<T>(private val allowedUseCases: EnumSet<SerializationContext.UseCase>) : Serializer<T>() {
protected fun checkUseCase() {
net.corda.serialization.internal.checkUseCase(allowedUseCases)
}
}
@ThreadSafe
object PrivateKeySerializer : UseCaseSerializer<PrivateKey>(EnumSet.of(Storage, Checkpoint)) {
object PrivateKeySerializer : Serializer<PrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: PrivateKey) {
checkUseCase()
output.writeBytesWithLength(obj.encoded)
}

View File

@ -10,10 +10,9 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationScheme
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.*
import java.security.PublicKey
@ -32,46 +31,30 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
}
abstract class AbstractKryoSerializationScheme : SerializationScheme {
object KryoSerializationScheme : CheckpointSerializationScheme {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool
protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool
// this can be overridden in derived serialization schemes
protected open val publicKeySerializer: Serializer<PublicKey> = PublicKeySerializer
private fun getPool(context: SerializationContext): KryoPool {
private fun getPool(context: CheckpointSerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()
SerializationContext.UseCase.RPCClient ->
rpcClientKryoPool(context)
SerializationContext.UseCase.RPCServer ->
rpcServerKryoPool(context)
else ->
KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context)), publicKeySerializer).apply { classLoader = it.second }
}.build()
}
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()
}
}
private fun <T : Any> SerializationContext.kryo(task: Kryo.() -> T): T {
private fun <T : Any> CheckpointSerializationContext.kryo(task: Kryo.() -> T): T {
return getPool(this).run { kryo ->
kryo.context.ensureCapacity(properties.size)
properties.forEach { kryo.context.put(it.key, it.value) }
@ -83,7 +66,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
val dataBytes = kryoMagic.consume(byteSequence)
?: throw KryoException("Serialized bytes header does not match expected format.")
return context.kryo {
@ -111,7 +94,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
override fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return context.kryo {
SerializedBytes(kryoOutput {
kryoMagic.writeTo(this)
@ -131,13 +114,11 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(
kryoMagic,
val KRYO_CHECKPOINT_CONTEXT = CheckpointSerializationContextImpl(
SerializationDefaults.javaClass.classLoader,
QuasarWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
null,
AlwaysAcceptEncodingWhitelist
)

View File

@ -1,14 +0,0 @@
package net.corda.node.serialization.kryo
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext
import net.corda.serialization.internal.CordaSerializationMagic
class KryoServerSerializationScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target == SerializationContext.UseCase.Checkpoint
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}

View File

@ -4,9 +4,9 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.*
import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage
@ -27,7 +27,7 @@ class ActionExecutorImpl(
private val checkpointStorage: CheckpointStorage,
private val flowMessaging: FlowMessaging,
private val stateMachineManager: StateMachineManagerInternal,
private val checkpointSerializationContext: SerializationContext,
private val checkpointSerializationContext: CheckpointSerializationContext,
metrics: MetricRegistry
) : ActionExecutor {
@ -237,7 +237,7 @@ class ActionExecutorImpl(
}
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
return checkpoint.checkpointSerialize(context = checkpointSerializationContext)
}
private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) {

View File

@ -12,8 +12,8 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
@ -69,7 +69,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: SerializationContext,
val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch
)
@ -369,7 +369,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
Event.Suspend(
ioRequest = ioRequest,
maySkipCheckpoint = skipPersistingCheckpoint,
fiber = this.serialize(context = serializationContext.value)
fiber = this.checkpointSerialize(context = serializationContext.value)
)
} catch (throwable: Throwable) {
Event.Error(throwable)

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
@ -36,7 +40,7 @@ import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.injectOldProgressTracker
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable
@ -103,7 +107,7 @@ class SingleThreadedStateMachineManager(
private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: SerializationContext? = null
private var checkpointSerializationContext: CheckpointSerializationContext? = null
private var actionExecutor: ActionExecutor? = null
override val allStateMachines: List<FlowLogic<*>>
@ -122,8 +126,8 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
@ -531,7 +535,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
@ -613,7 +617,7 @@ class SingleThreadedStateMachineManager(
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception)
null
@ -658,7 +662,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
val fiber = when (flowState) {
is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val logic = flowState.frozenFlowLogic.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -677,7 +681,7 @@ class SingleThreadedStateMachineManager(
fiber
}
is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -742,7 +746,7 @@ class SingleThreadedStateMachineManager(
}
}
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
private fun makeActionExecutor(checkpointSerializationContext: CheckpointSerializationContext): ActionExecutor {
return ActionExecutorImpl(
serviceHub,
checkpointStorage,

View File

@ -2,9 +2,9 @@ package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.ActionExecutor
import net.corda.node.services.statemachine.Event
@ -68,7 +68,7 @@ class FiberDeserializationChecker {
private val jobQueue = LinkedBlockingQueue<Job>()
private var foundUnrestorableFibers: Boolean = false
fun start(checkpointSerializationContext: SerializationContext) {
fun start(checkpointSerializationContext: CheckpointSerializationContext) {
require(checkerThread == null)
checkerThread = thread(name = "FiberDeserializationChecker") {
while (true) {
@ -76,7 +76,7 @@ class FiberDeserializationChecker {
when (job) {
is Job.Check -> {
try {
job.serializedFiber.deserialize(context = checkpointSerializationContext)
job.serializedFiber.checkpointDeserialize(context = checkpointSerializationContext)
} catch (throwable: Throwable) {
log.error("Encountered unrestorable checkpoint!", throwable)
foundUnrestorableFibers = true

View File

@ -20,10 +20,10 @@ import net.corda.core.flows.StateConsumptionDetails
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
@ -200,11 +200,11 @@ class RaftTransactionCommitLog<E, EK>(
}
class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = SerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = SerializationFactory.defaultFactory
private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = CheckpointSerializationFactory.defaultFactory
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.serialize(context = context)
val serialized = obj.checkpointSerialize(context = context)
buffer.writeInt(serialized.size)
buffer.write(serialized.bytes)
}

View File

@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.primitives.Ints
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
@ -13,6 +12,10 @@ import net.corda.core.contracts.PrivacySalt
import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.sequence
@ -36,16 +39,6 @@ import java.util.*
import kotlin.collections.ArrayList
import kotlin.test.*
class TestScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target != SerializationContext.UseCase.RPCClient
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}
@RunWith(Parameterized::class)
class KryoTests(private val compression: CordaSerializationEncoding?) {
companion object {
@ -55,18 +48,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) }
context = SerializationContextImpl(kryoMagic,
factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Storage,
compression,
rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
@ -77,15 +69,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday)
val bits = mike.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday))
val bits = mike.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday))
}
@Test
fun `null values`() {
val bob = Person("bob", null)
val bits = bob.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null))
val bits = bob.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null))
}
@Test
@ -93,10 +85,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext)
val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext)
originalList += obj
deserialisedList += obj
assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext))
assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext))
}
@Test
@ -113,14 +105,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant
this += instant
}
assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext))
assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext))
}
@Test
fun `cyclic object graph`() {
val cyclic = Cyclic(3)
val bits = cyclic.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic)
val bits = cyclic.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic)
}
@Test
@ -132,7 +124,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context)
val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -140,28 +132,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.serialize(factory, context)
val deserialised = serialised.deserialize(factory, context)
val serialised = TestSingleton.checkpointSerialize(factory, context)
val deserialised = serialised.checkpointDeserialize(factory, context)
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
}
@Test
fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
}
@Test
fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
}
@ -169,7 +161,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context)
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -179,7 +171,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.serialize(factory, context).deserialize(factory, context).iterator()
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext())
}
@ -190,16 +182,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.serialize(factory, context).bytes
val meta2 = serializedMetaData.deserialize<SignableData>(factory, context)
val serializedMetaData = meta.checkpointSerialize(factory, context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context)
assertEquals(meta2, meta)
}
@Test
fun `serialize - deserialize Logger`() {
val storageContext: SerializationContext = context // TODO: make it storage context
val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext)
val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext)
assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2)
}
@ -211,7 +203,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish),
rubbish.size,
rubbish.inputStream()
).serialize(factory, context).deserialize(factory, context)
).checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -238,8 +230,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32
))
val serializedBytes = expected.serialize(factory, context)
val actual = serializedBytes.deserialize(factory, context)
val serializedBytes = expected.checkpointSerialize(factory, context)
val actual = serializedBytes.checkpointDeserialize(factory, context)
assertEquals(expected, actual)
}
@ -286,15 +278,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
}
}
Tmp()
val factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) }
val context = SerializationContextImpl(kryoMagic,
val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.P2P,
null)
pt.serialize(factory, context)
pt.checkpointSerialize(factory, context)
}
@Test
@ -302,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size)
@ -317,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar")
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size)
@ -331,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(randomHash, exception2.requested)
}
@ -339,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() {
compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.serialize(factory, context)
val compressed = data.checkpointSerialize(factory, context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.deserialize(factory, context))
assertArrayEquals(data, compressed.checkpointDeserialize(factory, context))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".serialize(factory, context)
catchThrowable { compressed.deserialize(factory, context) }.run {
val compressed = "whatever".checkpointSerialize(factory, context)
catchThrowable { compressed.checkpointDeserialize(factory, context) }.run {
assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -360,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size
val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize)

View File

@ -3,9 +3,9 @@ package net.corda.node.services.persistence
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.internal.CheckpointIncompatibleException
import net.corda.node.internal.CheckpointVerifier
import net.corda.node.internal.configureDatabase
@ -189,9 +189,9 @@ class DBCheckpointStorageTests {
val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {}
}
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version)).getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
}
}

View File

@ -0,0 +1,49 @@
package net.corda.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
@KeepForDJVM
data class CheckpointSerializationContextImpl @JvmOverloads constructor(
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : CheckpointSerializationContext {
private val builder = AttachmentsClassLoaderBuilder(properties, deserializationClassLoader)
/**
* {@inheritDoc}
*
* We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
*/
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): CheckpointSerializationContext {
properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean == true || return this
val classLoader = builder.build(attachmentHashes) ?: return this
return withClassLoader(classLoader)
}
override fun withProperty(property: Any, value: Any): CheckpointSerializationContext {
return copy(properties = properties + (property to value))
}
override fun withoutReferences(): CheckpointSerializationContext {
return copy(objectReferencesEnabled = false)
}
override fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext {
return copy(deserializationClassLoader = classLoader)
}
override fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext {
return copy(whitelist = object : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name
})
}
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist)
}

View File

@ -3,14 +3,14 @@ package net.corda.serialization.internal
import net.corda.core.DeleteForDJVM
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
val serializationContextKey = SerializeAsTokenContext::class.java
fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext)
fun CheckpointSerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): CheckpointSerializationContext = this.withProperty(serializationContextKey, serializationContext)
/**
* A context for mapping SerializationTokens to/from SerializeAsTokens.
@ -55,6 +55,53 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser
}
}
override fun getSingleton(className: String) = classNameToSingleton[className]
?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
}
/**
* A context for mapping SerializationTokens to/from SerializeAsTokens.
*
* A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens.
* In our case this can be the [ServiceHub].
*
* Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary
* when serializing to enable/disable tokenization.
*/
@DeleteForDJVM
class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext {
constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializationFactory.serialize(toBeTokenized, context.withTokenContext(this))
})
private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>()
private var readOnly = false
init {
/**
* Go ahead and eagerly serialize the object to register all of the tokens in the context.
*
* This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which
* are encountered in the object graph as they are serialized and will therefore register the token to
* object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or
* accidental registrations from occuring as these could not be deserialized in a deserialization-first
* scenario if they are not part of this iniital context construction serialization.
*/
init(this)
readOnly = true
}
override fun putSingleton(toBeTokenized: SerializeAsToken) {
val className = toBeTokenized.javaClass.name
if (className !in classNameToSingleton) {
// Only allowable if we are in SerializeAsTokenContext init (readOnly == false)
if (readOnly) {
throw UnsupportedOperationException("Attempt to write token for lazy registered $className. All tokens should be registered during context construction.")
}
classNameToSingleton[className] = toBeTokenized
}
}
override fun getSingleton(className: String) = classNameToSingleton[className]
?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
}

View File

@ -13,3 +13,11 @@ fun checkUseCase(allowedUseCases: EnumSet<SerializationContext.UseCase>) {
throw IllegalStateException("UseCase '${currentContext.useCase}' is not within '$allowedUseCases'")
}
}
fun checkUseCase(allowedUseCase: SerializationContext.UseCase) {
val currentContext: SerializationContext = SerializationFactory.currentFactory?.currentContext
?: throw IllegalStateException("Current context is not set")
if (allowedUseCase != currentContext.useCase) {
throw IllegalStateException("UseCase '${currentContext.useCase}' is not '$allowedUseCase'")
}
}

View File

@ -163,8 +163,6 @@ abstract class AbstractAMQPSerializationScheme(
return synchronized(serializerFactoriesForContexts) {
serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
throw IllegalStateException("AMQP should not be used for checkpoint serialization.")
SerializationContext.UseCase.RPCClient ->
rpcClientSerializerFactory(context)
SerializationContext.UseCase.RPCServer ->

View File

@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp.custom
import net.corda.core.crypto.Crypto
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.serialization.internal.amqp.*
import net.corda.serialization.internal.checkUseCase
@ -13,14 +12,12 @@ import java.util.*
object PrivateKeySerializer : CustomSerializer.Implements<PrivateKey>(PrivateKey::class.java) {
private val allowedUseCases = EnumSet.of(Storage, Checkpoint)
override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList())))
override fun writeDescribedObject(obj: PrivateKey, data: Data, type: Type, output: SerializationOutput,
context: SerializationContext
) {
checkUseCase(allowedUseCases)
checkUseCase(Storage)
output.writeObject(obj.encoded, data, clazz, context)
}

View File

@ -4,7 +4,6 @@ import com.google.common.collect.Maps;
import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes;
import net.corda.serialization.internal.amqp.AMQPNotSerializableException;
import net.corda.serialization.internal.amqp.SchemaKt;
import net.corda.testing.core.SerializationEnvironmentRule;
import org.junit.Before;
@ -20,8 +19,10 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class ForbiddenLambdaSerializationTests {
private EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(
EnumSet.of(SerializationContext.UseCase.Checkpoint, SerializationContext.UseCase.Testing));
EnumSet.of(SerializationContext.UseCase.Testing));
@Rule
public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule();
private SerializationFactory factory;

View File

@ -1,11 +1,11 @@
package net.corda.serialization.internal;
import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes;
import net.corda.core.serialization.*;
import net.corda.core.serialization.internal.CheckpointSerializationContext;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.node.serialization.kryo.CordaClosureSerializer;
import net.corda.node.serialization.kryo.KryoSerializationSchemeKt;
import net.corda.testing.core.SerializationEnvironmentRule;
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -18,21 +18,22 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class LambdaCheckpointSerializationTest {
@Rule
public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule();
private SerializationFactory factory;
private SerializationContext context;
public final CheckpointSerializationEnvironmentRule testCheckpointSerialization =
new CheckpointSerializationEnvironmentRule();
private CheckpointSerializationFactory factory;
private CheckpointSerializationContext context;
@Before
public void setup() {
factory = testSerialization.getSerializationFactory();
context = new SerializationContextImpl(
KryoSerializationSchemeKt.getKryoMagic(),
factory = testCheckpointSerialization.getCheckpointSerializationFactory();
context = new CheckpointSerializationContextImpl(
getClass().getClassLoader(),
AllWhitelist.INSTANCE,
Collections.emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
null
);
}

View File

@ -3,8 +3,13 @@ package net.corda.serialization.internal
import net.corda.core.contracts.ContractAttachment
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
@ -17,28 +22,29 @@ import org.junit.Test
import kotlin.test.assertEquals
class ContractAttachmentSerializerTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
private lateinit var contextWithToken: SerializationContext
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
private lateinit var contextWithToken: CheckpointSerializationContext
private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock())
@Before
fun setup() {
factory = testSerialization.serializationFactory
context = testSerialization.checkpointContext
contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices))
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext
contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices))
}
@Test
fun `write contract attachment and read it back`() {
val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID)
// no token context so will serialize the whole attachment
val serialized = contractAttachment.serialize(factory, context)
val deserialized = serialized.deserialize(factory, context)
val serialized = contractAttachment.checkpointSerialize(factory, context)
val deserialized = serialized.checkpointDeserialize(factory, context)
assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract)
@ -53,8 +59,8 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
val deserialized = serialized.deserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract)
@ -70,7 +76,7 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
assertThat(serialized.size).isLessThan(largeAttachmentSize)
}
@ -82,8 +88,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
val deserialized = serialized.deserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java)
}
@ -94,8 +100,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
serialized.deserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
serialized.checkpointDeserialize(factory, contextWithToken)
// MissingAttachmentsException thrown if we try to open attachment
}

View File

@ -11,12 +11,11 @@ import com.nhaarman.mockito_kotlin.verify
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.node.serialization.kryo.CordaClassResolver
import net.corda.node.serialization.kryo.CordaKryo
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock
import net.corda.testing.services.MockAttachmentStorage
import org.junit.Rule
@ -115,8 +114,8 @@ class CordaClassResolverTests {
val emptyMapClass = mapOf<Any, Any>().javaClass
}
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null)
private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null)
private val emptyWhitelistContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, null)
private val allButBlacklistedContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, null)
@Test
fun `Annotation on enum works for specialised entries`() {
CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java)

View File

@ -3,6 +3,8 @@ package net.corda.serialization.internal
import net.corda.core.crypto.Crypto
import net.corda.core.serialization.SerializationContext.UseCase.*
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.serialization.serialize
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThatThrownBy
@ -33,13 +35,13 @@ class PrivateKeySerializationTest(private val privateKey: PrivateKey, private va
@Test
fun `passed with expected UseCases`() {
assertTrue { privateKey.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes.isNotEmpty() }
assertTrue { privateKey.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() }
assertTrue { privateKey.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() }
}
@Test
fun `failed with wrong UseCase`() {
assertThatThrownBy { privateKey.serialize(context = SerializationDefaults.P2P_CONTEXT) }
.isInstanceOf(IllegalStateException::class.java)
.hasMessageContaining("UseCase '$P2P' is not within")
.hasMessageContaining("UseCase '$P2P' is not 'Storage")
}
}

View File

@ -4,6 +4,10 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.OpaqueBytes
import net.corda.node.serialization.kryo.CordaClassResolver
import net.corda.node.serialization.kryo.CordaKryo
@ -11,6 +15,7 @@ import net.corda.node.serialization.kryo.DefaultKryoCustomizer
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
@ -18,16 +23,18 @@ import org.junit.Test
import java.io.ByteArrayOutputStream
class SerializationTokenTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = testSerialization.serializationFactory
context = testSerialization.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java)
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java)
}
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized
@ -42,16 +49,16 @@ class SerializationTokenTest {
override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size
}
private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock())
private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock())
@Test
fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@ -62,8 +69,8 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@ -72,7 +79,7 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
tokenizableBefore.serialize(factory, testContext)
tokenizableBefore.checkpointSerialize(factory, testContext)
}
@Test(expected = UnsupportedOperationException::class)
@ -80,14 +87,14 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).serialize(factory, testContext)
serializedBytes.deserialize(factory, testContext)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
}
@Test(expected = KryoException::class)
fun `no context set`() {
val tokenizableBefore = UnitSerializeAsToken()
tokenizableBefore.serialize(factory, context)
tokenizableBefore.checkpointSerialize(factory, context)
}
@Test(expected = KryoException::class)
@ -105,7 +112,7 @@ class SerializationTokenTest {
kryo.writeObject(it, emptyList<Any>())
}
val serializedBytes = SerializedBytes<Any>(stream.toByteArray())
serializedBytes.deserialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
}
private class WrongTypeSerializeAsToken : SerializeAsToken {
@ -121,7 +128,7 @@ class SerializationTokenTest {
val tokenizableBefore = WrongTypeSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
serializedBytes.deserialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
}
}

View File

@ -5,6 +5,7 @@ import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.DoNotImplement
import net.corda.core.internal.staticField
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv
@ -45,7 +46,6 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
private lateinit var env: SerializationEnvironment
val serializationFactory get() = env.serializationFactory
val checkpointContext get() = env.checkpointContext
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())

View File

@ -0,0 +1,71 @@
package net.corda.testing.core.internal
import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.staticField
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.createTestSerializationEnv
import net.corda.testing.internal.inVMExecutors
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
/**
* A test checkpoint serialization rule implementation for use in tests.
*
* @param inheritable whether new threads inherit the environment, use sparingly.
*/
class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
companion object {
init {
// Can't turn it off, and it creates threads that do serialization, so hack it:
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
doAnswer {
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
}.execute(it.arguments[0] as Runnable)
}.whenever(it).execute(any())
}
}
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task)
}
}
private lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())
return object : Statement() {
override fun evaluate() = runTask { base.evaluate() }
}
}
private fun init(envLabel: String) {
env = createTestSerializationEnv(envLabel)
}
private fun <T> runTask(task: (SerializationEnvironment) -> T): T {
try {
return env.asContextEnv(inheritable, task)
} finally {
inVMExecutors.remove(env)
}
}
val checkpointSerializationFactory get() = env.checkpointSerializationFactory
val checkpointSerializationContext get() = env.checkpointContext
}

View File

@ -4,10 +4,11 @@ import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.DoNotImplement
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.*
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.serialization.internal.*
import net.corda.testing.core.SerializationEnvironmentRule
import java.util.concurrent.ConcurrentHashMap
@ -33,8 +34,6 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
val factory = SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList()))
registerScheme(AMQPServerSerializationScheme(emptyList()))
// needed for checkpointing
registerScheme(KryoServerSerializationScheme())
}
return object : SerializationEnvironmentImpl(
factory,
@ -42,7 +41,8 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
AMQP_RPC_SERVER_CONTEXT,
AMQP_RPC_CLIENT_CONTEXT,
AMQP_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT
KRYO_CHECKPOINT_CONTEXT,
CheckpointSerializationFactory(KryoSerializationScheme)
) {
override fun toString() = "testSerializationEnv($label)"
}

View File

@ -17,10 +17,7 @@ import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.base64ToByteArray
import net.corda.core.utilities.hexToByteArray
import net.corda.core.utilities.sequence
import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.CordaSerializationMagic
import net.corda.serialization.internal.SerializationFactoryImpl
import net.corda.serialization.internal.*
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.amqpMagic

View File

@ -3,6 +3,7 @@ package net.corda.bootstrapper.serialization
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.SerializationFactoryImpl
@ -20,7 +21,7 @@ class SerializationEngine {
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = AMQP_P2P_CONTEXT.withClassLoader(classloader)
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
)
}
}