CORDA-1391: Separate out Checkpoint serialization (#3922)

* Separate out Checkpoint serialization

* Update kdocs

* Rename checkpoint serialization extension methods

* Fix bungled rename

* Limit API changes

* Simplify CheckpointSerializationFactory

* Add CheckpointSerializationScheme to API checker

* CheckpointSerializationScheme should not be implemented

* Move checkpoint serialisation to internal package

* Remove CheckpointSerializationScheme from api-current

* Quarantine internal classes

* Remove checkpoint context from public API

* Remove checkpoint context from public API

* Fix test failures

* Completely decouple SerializationTestHelpers and CheckpointSerializationTestHelpers

* Remove CHECKPOINT use case

* Remove stray reference to checkpoint use case

* Fix broken test
This commit is contained in:
Dominic Fox 2018-09-19 14:23:29 +01:00 committed by GitHub
parent d10892c09e
commit 98c92ef16f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 677 additions and 253 deletions

View File

@ -4445,8 +4445,6 @@ public interface net.corda.core.serialization.SerializationCustomSerializer
public abstract PROXY toProxy(OBJ) public abstract PROXY toProxy(OBJ)
## ##
public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object
@NotNull
public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT()
@NotNull @NotNull
public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT() public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT()
@NotNull @NotNull
@ -6883,8 +6881,6 @@ public final class net.corda.testing.core.SerializationEnvironmentRule extends j
@NotNull @NotNull
public org.junit.runners.model.Statement apply(org.junit.runners.model.Statement, org.junit.runner.Description) public org.junit.runners.model.Statement apply(org.junit.runners.model.Statement, org.junit.runner.Description)
@NotNull @NotNull
public final net.corda.core.serialization.SerializationContext getCheckpointContext()
@NotNull
public final net.corda.core.serialization.SerializationFactory getSerializationFactory() public final net.corda.core.serialization.SerializationFactory getSerializationFactory()
public static final net.corda.testing.core.SerializationEnvironmentRule$Companion Companion public static final net.corda.testing.core.SerializationEnvironmentRule$Companion Companion
## ##

View File

@ -50,6 +50,7 @@ task patchCore(type: Zip, dependsOn: coreJarTask) {
from(zipTree(originalJar)) { from(zipTree(originalJar)) {
exclude 'net/corda/core/internal/*ToggleField*.class' exclude 'net/corda/core/internal/*ToggleField*.class'
exclude 'net/corda/core/serialization/*SerializationFactory*.class' exclude 'net/corda/core/serialization/*SerializationFactory*.class'
exclude 'net/corda/core/serialization/internal/CheckpointSerializationFactory*.class'
} }
reproducibleFileOrder = true reproducibleFileOrder = true

View File

@ -0,0 +1,74 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/**
* A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization
* context.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private var _currentContext: CheckpointSerializationContext? = null
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext
if (context != null) _currentContext = context
try {
return block()
} finally {
if (context != null) _currentContext = priorContext
}
}
companion object {
/**
* A default factory for serialization/deserialization.
*/
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}

View File

@ -207,7 +207,13 @@ interface SerializationContext {
* The use case that we are serializing for, since it influences the implementations chosen. * The use case that we are serializing for, since it influences the implementations chosen.
*/ */
@KeepForDJVM @KeepForDJVM
enum class UseCase { P2P, RPCServer, RPCClient, Storage, Checkpoint, Testing } enum class UseCase {
P2P,
RPCServer,
RPCClient,
Storage,
Testing
}
} }
/** /**
@ -230,7 +236,6 @@ object SerializationDefaults {
@DeleteForDJVM val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext @DeleteForDJVM val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext
@DeleteForDJVM val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext @DeleteForDJVM val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext
@DeleteForDJVM val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext @DeleteForDJVM val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext
@DeleteForDJVM val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
} }
/** /**

View File

@ -0,0 +1,198 @@
package net.corda.core.serialization.internal
import net.corda.core.DeleteForDJVM
import net.corda.core.DoNotImplement
import net.corda.core.KeepForDJVM
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.sequence
import java.io.NotSerializableException
object CheckpointSerializationDefaults {
@DeleteForDJVM
val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory
}
/**
* A class for serializing and deserializing objects at checkpoints, using Kryo serialization.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private val _currentContext = ThreadLocal<CheckpointSerializationContext?>()
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}
companion object {
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationScheme {
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T>
}
/**
* Parameters to checkpoint serialization and deserialization.
*/
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationContext {
/**
* If non-null, apply this encoding (typically compression) when serializing.
*/
val encoding: SerializationEncoding?
/**
* The class loader to use for deserialization.
*/
val deserializationClassLoader: ClassLoader
/**
* A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized.
*/
val whitelist: ClassWhitelist
/**
* A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing.
*/
val encodingWhitelist: EncodingWhitelist
/**
* A map of any addition properties specific to the particular use case.
*/
val properties: Map<Any, Any>
/**
* Duplicate references to the same object preserved in the wire format and when deserialized when this is true,
* otherwise they appear as new copies of the object.
*/
val objectReferencesEnabled: Boolean
/**
* Helper method to return a new context based on this context with the property added.
*/
fun withProperty(property: Any, value: Any): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with object references disabled.
*/
fun withoutReferences(): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the deserialization class loader changed.
*/
fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers.
* (Requires the attachment storage to have been enabled).
*/
@Throws(MissingAttachmentsException::class)
fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the given class specifically whitelisted.
*/
fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given (possibly null) encoding.
*/
fun withEncoding(encoding: SerializationEncoding?): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given encoding whitelist.
*/
fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext
}
/*
* The following extension methods are disambiguated from the AMQP-serialization methods by requiring that an
* explicit [CheckpointSerializationContext] parameter be provided.
*/
/*
* Convenience extension method for deserializing a ByteSequence, utilising the default factory.
*/
inline fun <reified T : Any> ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory.
*/
inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing a ByteArray, utilising the default factory.
*/
inline fun <reified T : Any> ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
require(isNotEmpty()) { "Empty bytes" }
return this.sequence().checkpointDeserialize(serializationFactory, context)
}
/**
* Convenience extension method for serializing an object of type T, utilising the default factory.
*/
fun <T : Any> T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): SerializedBytes<T> {
return serializationFactory.serialize(this, context)
}

View File

@ -12,11 +12,12 @@ import net.corda.core.serialization.SerializationFactory
@KeepForDJVM @KeepForDJVM
interface SerializationEnvironment { interface SerializationEnvironment {
val serializationFactory: SerializationFactory val serializationFactory: SerializationFactory
val checkpointSerializationFactory: CheckpointSerializationFactory
val p2pContext: SerializationContext val p2pContext: SerializationContext
val rpcServerContext: SerializationContext val rpcServerContext: SerializationContext
val rpcClientContext: SerializationContext val rpcClientContext: SerializationContext
val storageContext: SerializationContext val storageContext: SerializationContext
val checkpointContext: SerializationContext val checkpointContext: CheckpointSerializationContext
} }
@KeepForDJVM @KeepForDJVM
@ -26,18 +27,21 @@ open class SerializationEnvironmentImpl(
rpcServerContext: SerializationContext? = null, rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null, rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null, storageContext: SerializationContext? = null,
checkpointContext: SerializationContext? = null) : SerializationEnvironment { checkpointContext: CheckpointSerializationContext? = null,
checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited: // Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: SerializationContext override lateinit var checkpointContext: CheckpointSerializationContext
override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory
init { init {
rpcServerContext?.let { this.rpcServerContext = it } rpcServerContext?.let { this.rpcServerContext = it }
rpcClientContext?.let { this.rpcClientContext = it } rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it } storageContext?.let { this.storageContext = it }
checkpointContext?.let { this.checkpointContext = it } checkpointContext?.let { this.checkpointContext = it }
checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it }
} }
} }

View File

@ -1,11 +1,14 @@
package net.corda.core.flows; package net.corda.core.flows;
import net.corda.core.serialization.internal.CheckpointSerializationDefaults;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.core.serialization.SerializationDefaults; import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory; import net.corda.core.serialization.SerializationFactory;
import net.corda.testing.core.SerializationEnvironmentRule; import net.corda.testing.core.SerializationEnvironmentRule;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import static net.corda.core.serialization.internal.CheckpointSerializationAPIKt.checkpointSerialize;
import static net.corda.core.serialization.SerializationAPIKt.serialize; import static net.corda.core.serialization.SerializationAPIKt.serialize;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
@ -28,10 +31,13 @@ public class SerializationApiInJavaTest {
public void enforceSerializationDefaultsApi() { public void enforceSerializationDefaultsApi() {
SerializationDefaults defaults = SerializationDefaults.INSTANCE; SerializationDefaults defaults = SerializationDefaults.INSTANCE;
SerializationFactory factory = defaults.getSERIALIZATION_FACTORY(); SerializationFactory factory = defaults.getSERIALIZATION_FACTORY();
CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE;
CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY();
serialize("hello", factory, defaults.getP2P_CONTEXT()); serialize("hello", factory, defaults.getP2P_CONTEXT());
serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT()); serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT());
serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT()); serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT());
serialize("hello", factory, defaults.getSTORAGE_CONTEXT()); serialize("hello", factory, defaults.getSTORAGE_CONTEXT());
serialize("hello", factory, defaults.getCHECKPOINT_CONTEXT()); checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT());
} }
} }

View File

@ -3,9 +3,10 @@ package net.corda.core.utilities
import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.KryoException
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.kryoMagic import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.serialization.internal.SerializationContextImpl
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule import org.junit.Rule
@ -24,12 +25,11 @@ class KotlinUtilsTest {
@Rule @Rule
val expectedEx: ExpectedException = ExpectedException.none() val expectedEx: ExpectedException = ExpectedException.none()
private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = SerializationContextImpl(kryoMagic, private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = CheckpointSerializationContextImpl(
javaClass.classLoader, javaClass.classLoader,
EmptyWhitelist, EmptyWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Checkpoint,
null) null)
@Test @Test
@ -44,7 +44,7 @@ class KotlinUtilsTest {
fun `checkpointing a transient property with non-capturing lambda`() { fun `checkpointing a transient property with non-capturing lambda`() {
val original = NonCapturingTransientProperty() val original = NonCapturingTransientProperty()
val originalVal = original.transientVal val originalVal = original.transientVal
val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT) val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal) assertThat(copy.transientVal).isEqualTo(copyVal)
@ -55,14 +55,14 @@ class KotlinUtilsTest {
expectedEx.expect(KryoException::class.java) expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = NonCapturingTransientProperty() val original = NonCapturingTransientProperty()
original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT) original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
} }
@Test @Test
fun `checkpointing a transient property with capturing lambda`() { fun `checkpointing a transient property with capturing lambda`() {
val original = CapturingTransientProperty("Hello") val original = CapturingTransientProperty("Hello")
val originalVal = original.transientVal val originalVal = original.transientVal
val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT) val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal) assertThat(copy.transientVal).isEqualTo(copyVal)
@ -76,7 +76,7 @@ class KotlinUtilsTest {
val original = CapturingTransientProperty("Hello") val original = CapturingTransientProperty("Hello")
original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT) original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
} }
private class NullTransientProperty { private class NullTransientProperty {

View File

@ -5,11 +5,12 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.SubFlow import net.corda.node.services.statemachine.SubFlow
import net.corda.node.services.statemachine.SubFlowVersion import net.corda.node.services.statemachine.SubFlowVersion
import net.corda.serialization.internal.SerializeAsTokenContextImpl import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext import net.corda.serialization.internal.withTokenContext
object CheckpointVerifier { object CheckpointVerifier {
@ -19,13 +20,13 @@ object CheckpointVerifier {
* @throws CheckpointIncompatibleException if any offending checkpoint is found. * @throws CheckpointIncompatibleException if any offending checkpoint is found.
*/ */
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) { fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
) )
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) -> checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->
val checkpoint = try { val checkpoint = try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext) serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext)
} catch (e: Exception) { } catch (e: Exception) {
throw CheckpointIncompatibleException.CannotBeDeserialisedException(e) throw CheckpointIncompatibleException.CannotBeDeserialisedException(e)
} }

View File

@ -21,6 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
@ -37,7 +38,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.services.Permissions import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
@ -449,8 +450,8 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme())
}, },
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),

View File

@ -8,8 +8,8 @@ import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util import com.esotericsoftware.kryo.util.Util
import net.corda.core.internal.kotlinObjectInstance import net.corda.core.internal.kotlinObjectInstance
import net.corda.core.internal.writer import net.corda.core.internal.writer
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.serialization.internal.AttachmentsClassLoader import net.corda.serialization.internal.AttachmentsClassLoader
import net.corda.serialization.internal.MutableClassWhitelist import net.corda.serialization.internal.MutableClassWhitelist
@ -25,7 +25,7 @@ import java.util.*
/** /**
* Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo * Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo
*/ */
class CordaClassResolver(serializationContext: SerializationContext) : DefaultClassResolver() { class CordaClassResolver(serializationContext: CheckpointSerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist) val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
// These classes are assignment-compatible Java equivalents of Kotlin classes. // These classes are assignment-compatible Java equivalents of Kotlin classes.

View File

@ -14,12 +14,11 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.core.serialization.SerializeAsTokenContext import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.* import net.corda.core.transactions.*
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.checkUseCase
import net.corda.serialization.internal.serializationContextKey import net.corda.serialization.internal.serializationContextKey
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
@ -275,16 +274,9 @@ object SignedTransactionSerializer : Serializer<SignedTransaction>() {
} }
} }
sealed class UseCaseSerializer<T>(private val allowedUseCases: EnumSet<SerializationContext.UseCase>) : Serializer<T>() {
protected fun checkUseCase() {
net.corda.serialization.internal.checkUseCase(allowedUseCases)
}
}
@ThreadSafe @ThreadSafe
object PrivateKeySerializer : UseCaseSerializer<PrivateKey>(EnumSet.of(Storage, Checkpoint)) { object PrivateKeySerializer : Serializer<PrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: PrivateKey) { override fun write(kryo: Kryo, output: Output, obj: PrivateKey) {
checkUseCase()
output.writeBytesWithLength(obj.encoded) output.writeBytesWithLength(obj.encoded)
} }

View File

@ -10,10 +10,9 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.ClosureSerializer import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.CheckpointSerializationScheme
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import java.security.PublicKey import java.security.PublicKey
@ -32,46 +31,30 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!") override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
} }
abstract class AbstractKryoSerializationScheme : SerializationScheme { object KryoSerializationScheme : CheckpointSerializationScheme {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>() private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool private fun getPool(context: CheckpointSerializationContext): KryoPool {
protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool
// this can be overridden in derived serialization schemes
protected open val publicKeySerializer: Serializer<PublicKey> = PublicKeySerializer
private fun getPool(context: SerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) { KryoPool.Builder {
SerializationContext.UseCase.Checkpoint -> val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
KryoPool.Builder { val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) } val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that serializer.kryo.apply {
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } field.set(this, classResolver)
serializer.kryo.apply { // don't allow overriding the public key serializer for checkpointing
field.set(this, classResolver) DefaultKryoCustomizer.customize(this)
// don't allow overriding the public key serializer for checkpointing addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
DefaultKryoCustomizer.customize(this) register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) classLoader = it.second
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer) }
classLoader = it.second }.build()
}
}.build()
SerializationContext.UseCase.RPCClient ->
rpcClientKryoPool(context)
SerializationContext.UseCase.RPCServer ->
rpcServerKryoPool(context)
else ->
KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context)), publicKeySerializer).apply { classLoader = it.second }
}.build()
}
} }
} }
private fun <T : Any> SerializationContext.kryo(task: Kryo.() -> T): T { private fun <T : Any> CheckpointSerializationContext.kryo(task: Kryo.() -> T): T {
return getPool(this).run { kryo -> return getPool(this).run { kryo ->
kryo.context.ensureCapacity(properties.size) kryo.context.ensureCapacity(properties.size)
properties.forEach { kryo.context.put(it.key, it.value) } properties.forEach { kryo.context.put(it.key, it.value) }
@ -83,7 +66,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
} }
} }
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T { override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
val dataBytes = kryoMagic.consume(byteSequence) val dataBytes = kryoMagic.consume(byteSequence)
?: throw KryoException("Serialized bytes header does not match expected format.") ?: throw KryoException("Serialized bytes header does not match expected format.")
return context.kryo { return context.kryo {
@ -111,7 +94,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
} }
} }
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> { override fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return context.kryo { return context.kryo {
SerializedBytes(kryoOutput { SerializedBytes(kryoOutput {
kryoMagic.writeTo(this) kryoMagic.writeTo(this)
@ -131,13 +114,11 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
} }
} }
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl( val KRYO_CHECKPOINT_CONTEXT = CheckpointSerializationContextImpl(
kryoMagic,
SerializationDefaults.javaClass.classLoader, SerializationDefaults.javaClass.classLoader,
QuasarWhitelist, QuasarWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Checkpoint,
null, null,
AlwaysAcceptEncodingWhitelist AlwaysAcceptEncodingWhitelist
) )

View File

@ -1,14 +0,0 @@
package net.corda.node.serialization.kryo
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext
import net.corda.serialization.internal.CordaSerializationMagic
class KryoServerSerializationScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target == SerializationContext.UseCase.Checkpoint
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}

View File

@ -4,9 +4,9 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.* import com.codahale.metrics.*
import net.corda.core.internal.concurrent.thenMatch import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.*
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.serialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
@ -27,7 +27,7 @@ class ActionExecutorImpl(
private val checkpointStorage: CheckpointStorage, private val checkpointStorage: CheckpointStorage,
private val flowMessaging: FlowMessaging, private val flowMessaging: FlowMessaging,
private val stateMachineManager: StateMachineManagerInternal, private val stateMachineManager: StateMachineManagerInternal,
private val checkpointSerializationContext: SerializationContext, private val checkpointSerializationContext: CheckpointSerializationContext,
metrics: MetricRegistry metrics: MetricRegistry
) : ActionExecutor { ) : ActionExecutor {
@ -237,7 +237,7 @@ class ActionExecutorImpl(
} }
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> { private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext) return checkpoint.checkpointSerialize(context = checkpointSerializationContext)
} }
private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) { private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) {

View File

@ -12,8 +12,8 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.* import net.corda.core.internal.*
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.serialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
@ -69,7 +69,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val actionExecutor: ActionExecutor, val actionExecutor: ActionExecutor,
val stateMachine: StateMachine, val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal, val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: SerializationContext, val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch val unfinishedFibers: ReusableLatch
) )
@ -369,7 +369,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
Event.Suspend( Event.Suspend(
ioRequest = ioRequest, ioRequest = ioRequest,
maySkipCheckpoint = skipPersistingCheckpoint, maySkipCheckpoint = skipPersistingCheckpoint,
fiber = this.serialize(context = serializationContext.value) fiber = this.checkpointSerialize(context = serializationContext.value)
) )
} catch (throwable: Throwable) { } catch (throwable: Throwable) {
Event.Error(throwable) Event.Error(throwable)

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
@ -36,7 +40,7 @@ import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.injectOldProgressTracker import net.corda.node.utilities.injectOldProgressTracker
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import net.corda.serialization.internal.SerializeAsTokenContextImpl import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext import net.corda.serialization.internal.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable import rx.Observable
@ -103,7 +107,7 @@ class SingleThreadedStateMachineManager(
private val transitionExecutor = makeTransitionExecutor() private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: SerializationContext? = null private var checkpointSerializationContext: CheckpointSerializationContext? = null
private var actionExecutor: ActionExecutor? = null private var actionExecutor: ActionExecutor? = null
override val allStateMachines: List<FlowLogic<*>> override val allStateMachines: List<FlowLogic<*>>
@ -122,8 +126,8 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) { override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence() checkQuasarJavaAgentPresence()
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
) )
this.checkpointSerializationContext = checkpointSerializationContext this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext) this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
@ -531,7 +535,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
@ -613,7 +617,7 @@ class SingleThreadedStateMachineManager(
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? { private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try { return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!) serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) { } catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception) logger.error("Encountered unrestorable checkpoint!", exception)
null null
@ -658,7 +662,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
val fiber = when (flowState) { val fiber = when (flowState) {
is FlowState.Unstarted -> { is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) val logic = flowState.frozenFlowLogic.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -677,7 +681,7 @@ class SingleThreadedStateMachineManager(
fiber fiber
} }
is FlowState.Started -> { is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState( val state = StateMachineState(
checkpoint = checkpoint, checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(), pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -742,7 +746,7 @@ class SingleThreadedStateMachineManager(
} }
} }
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor { private fun makeActionExecutor(checkpointSerializationContext: CheckpointSerializationContext): ActionExecutor {
return ActionExecutorImpl( return ActionExecutorImpl(
serviceHub, serviceHub,
checkpointStorage, checkpointStorage,

View File

@ -2,9 +2,9 @@ package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.*
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.deserialize import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.ActionExecutor import net.corda.node.services.statemachine.ActionExecutor
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
@ -68,7 +68,7 @@ class FiberDeserializationChecker {
private val jobQueue = LinkedBlockingQueue<Job>() private val jobQueue = LinkedBlockingQueue<Job>()
private var foundUnrestorableFibers: Boolean = false private var foundUnrestorableFibers: Boolean = false
fun start(checkpointSerializationContext: SerializationContext) { fun start(checkpointSerializationContext: CheckpointSerializationContext) {
require(checkerThread == null) require(checkerThread == null)
checkerThread = thread(name = "FiberDeserializationChecker") { checkerThread = thread(name = "FiberDeserializationChecker") {
while (true) { while (true) {
@ -76,7 +76,7 @@ class FiberDeserializationChecker {
when (job) { when (job) {
is Job.Check -> { is Job.Check -> {
try { try {
job.serializedFiber.deserialize(context = checkpointSerializationContext) job.serializedFiber.checkpointDeserialize(context = checkpointSerializationContext)
} catch (throwable: Throwable) { } catch (throwable: Throwable) {
log.error("Encountered unrestorable checkpoint!", throwable) log.error("Encountered unrestorable checkpoint!", throwable)
foundUnrestorableFibers = true foundUnrestorableFibers = true

View File

@ -20,10 +20,10 @@ import net.corda.core.flows.StateConsumptionDetails
import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.notary.isConsumedByTheSameTx import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.deserialize import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.serialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
@ -200,11 +200,11 @@ class RaftTransactionCommitLog<E, EK>(
} }
class CordaKryoSerializer<T : Any> : TypeSerializer<T> { class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = SerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY) private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = SerializationFactory.defaultFactory private val factory = CheckpointSerializationFactory.defaultFactory
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) { override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.serialize(context = context) val serialized = obj.checkpointSerialize(context = context)
buffer.writeInt(serialized.size) buffer.writeInt(serialized.size)
buffer.write(serialized.bytes) buffer.write(serialized.bytes)
} }

View File

@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.primitives.Ints import com.google.common.primitives.Ints
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
@ -13,6 +12,10 @@ import net.corda.core.contracts.PrivacySalt
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.sequence import net.corda.core.utilities.sequence
@ -36,16 +39,6 @@ import java.util.*
import kotlin.collections.ArrayList import kotlin.collections.ArrayList
import kotlin.test.* import kotlin.test.*
class TestScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target != SerializationContext.UseCase.RPCClient
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}
@RunWith(Parameterized::class) @RunWith(Parameterized::class)
class KryoTests(private val compression: CordaSerializationEncoding?) { class KryoTests(private val compression: CordaSerializationEncoding?) {
companion object { companion object {
@ -55,18 +48,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values() fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
} }
private lateinit var factory: SerializationFactory private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: SerializationContext private lateinit var context: CheckpointSerializationContext
@Before @Before
fun setup() { fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) } factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = SerializationContextImpl(kryoMagic, context = CheckpointSerializationContextImpl(
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.Storage,
compression, compression,
rigorousMock<EncodingWhitelist>().also { rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression) if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
@ -77,15 +69,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() { fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z") val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday) val mike = Person("mike", birthday)
val bits = mike.serialize(factory, context) val bits = mike.checkpointSerialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday)) assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday))
} }
@Test @Test
fun `null values`() { fun `null values`() {
val bob = Person("bob", null) val bob = Person("bob", null)
val bits = bob.serialize(factory, context) val bits = bob.checkpointSerialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null)) assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null))
} }
@Test @Test
@ -93,10 +85,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences() val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence() val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj } val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext) val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext)
originalList += obj originalList += obj
deserialisedList += obj deserialisedList += obj
assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext)) assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext))
} }
@Test @Test
@ -113,14 +105,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant this += instant
this += instant this += instant
} }
assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext)) assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext))
} }
@Test @Test
fun `cyclic object graph`() { fun `cyclic object graph`() {
val cyclic = Cyclic(3) val cyclic = Cyclic(3)
val bits = cyclic.serialize(factory, context) val bits = cyclic.checkpointSerialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic) assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic)
} }
@Test @Test
@ -132,7 +124,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign) signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) } assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context) val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign) deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) } assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -140,28 +132,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `write and read Kotlin object singleton`() { fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.serialize(factory, context) val serialised = TestSingleton.checkpointSerialize(factory, context)
val deserialised = serialised.deserialize(factory, context) val deserialised = serialised.checkpointDeserialize(factory, context)
assertThat(deserialised).isSameAs(TestSingleton) assertThat(deserialised).isSameAs(TestSingleton)
} }
@Test @Test
fun `check Kotlin EmptyList can be serialised`() { fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().serialize(factory, context).deserialize(factory, context) val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedList.size) assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass) assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
} }
@Test @Test
fun `check Kotlin EmptySet can be serialised`() { fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().serialize(factory, context).deserialize(factory, context) val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedSet.size) assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass) assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
} }
@Test @Test
fun `check Kotlin EmptyMap can be serialised`() { fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().serialize(factory, context).deserialize(factory, context) val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedMap.size) assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass) assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
} }
@ -169,7 +161,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `InputStream serialisation`() { fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() } val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context) val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) { for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte()) assertEquals(rubbish[i], readRubbishStream.read().toByte())
} }
@ -179,7 +171,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `InputStream serialisation does not write trailing garbage`() { fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() } val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.serialize(factory, context).deserialize(factory, context).iterator() val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) } byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext()) assertFalse(streams.hasNext())
} }
@ -190,16 +182,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray() val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)) val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.serialize(factory, context).bytes val serializedMetaData = meta.checkpointSerialize(factory, context).bytes
val meta2 = serializedMetaData.deserialize<SignableData>(factory, context) val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context)
assertEquals(meta2, meta) assertEquals(meta2, meta)
} }
@Test @Test
fun `serialize - deserialize Logger`() { fun `serialize - deserialize Logger`() {
val storageContext: SerializationContext = context // TODO: make it storage context val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName") val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext) val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext)
assertEquals(logger.name, logger2.name) assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2) assertTrue(logger === logger2)
} }
@ -211,7 +203,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish), SecureHash.sha256(rubbish),
rubbish.size, rubbish.size,
rubbish.inputStream() rubbish.inputStream()
).serialize(factory, context).deserialize(factory, context) ).checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) { for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte()) assertEquals(rubbish[i], readRubbishStream.read().toByte())
} }
@ -238,8 +230,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32 31, 32
)) ))
val serializedBytes = expected.serialize(factory, context) val serializedBytes = expected.checkpointSerialize(factory, context)
val actual = serializedBytes.deserialize(factory, context) val actual = serializedBytes.checkpointDeserialize(factory, context)
assertEquals(expected, actual) assertEquals(expected, actual)
} }
@ -286,15 +278,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
} }
} }
Tmp() Tmp()
val factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) } val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = SerializationContextImpl(kryoMagic, val context = CheckpointSerializationContextImpl(
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
emptyMap(), emptyMap(),
true, true,
SerializationContext.UseCase.P2P,
null) null)
pt.serialize(factory, context) pt.checkpointSerialize(factory, context)
} }
@Test @Test
@ -302,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar") val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1") val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide) exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.serialize(factory, context).deserialize(factory, context) val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message) assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size) assertEquals(1, exception2.suppressed.size)
@ -317,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test @Test
fun `serialize - deserialize Exception no suppressed`() { fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar") val exception = IllegalArgumentException("fooBar")
val exception2 = exception.serialize(factory, context).deserialize(factory, context) val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message) assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size) assertEquals(0, exception2.suppressed.size)
@ -331,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() { fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256() val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash) val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.serialize(factory, context).deserialize(factory, context) val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(randomHash, exception2.requested) assertEquals(randomHash, exception2.requested)
} }
@ -339,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() { fun `compression has the desired effect`() {
compression ?: return compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it } val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.serialize(factory, context) val compressed = data.checkpointSerialize(factory, context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03) assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.deserialize(factory, context)) assertArrayEquals(data, compressed.checkpointDeserialize(factory, context))
} }
@Test @Test
fun `a particular encoding can be banned for deserialization`() { fun `a particular encoding can be banned for deserialization`() {
compression ?: return compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression) doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".serialize(factory, context) val compressed = "whatever".checkpointSerialize(factory, context)
catchThrowable { compressed.deserialize(factory, context) }.run { catchThrowable { compressed.checkpointDeserialize(factory, context) }.run {
assertSame<Any>(KryoException::class.java, javaClass) assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message) assertEquals(encodingNotPermittedFormat.format(compression), message)
} }
@ -360,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray) class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000)) val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size
val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade. // If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize) assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize) assertEquals(1111, compressedSize)

View File

@ -3,9 +3,9 @@ package net.corda.node.services.persistence
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.internal.CheckpointIncompatibleException import net.corda.node.internal.CheckpointIncompatibleException
import net.corda.node.internal.CheckpointVerifier import net.corda.node.internal.CheckpointVerifier
import net.corda.node.internal.configureDatabase import net.corda.node.internal.configureDatabase
@ -189,9 +189,9 @@ class DBCheckpointStorageTests {
val logic: FlowLogic<*> = object : FlowLogic<Unit>() { val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {} override fun call() {}
} }
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version)).getOrThrow() val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version)).getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
} }
} }

View File

@ -0,0 +1,49 @@
package net.corda.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
@KeepForDJVM
data class CheckpointSerializationContextImpl @JvmOverloads constructor(
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : CheckpointSerializationContext {
private val builder = AttachmentsClassLoaderBuilder(properties, deserializationClassLoader)
/**
* {@inheritDoc}
*
* We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
*/
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): CheckpointSerializationContext {
properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean == true || return this
val classLoader = builder.build(attachmentHashes) ?: return this
return withClassLoader(classLoader)
}
override fun withProperty(property: Any, value: Any): CheckpointSerializationContext {
return copy(properties = properties + (property to value))
}
override fun withoutReferences(): CheckpointSerializationContext {
return copy(objectReferencesEnabled = false)
}
override fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext {
return copy(deserializationClassLoader = classLoader)
}
override fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext {
return copy(whitelist = object : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name
})
}
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist)
}

View File

@ -3,14 +3,14 @@ package net.corda.serialization.internal
import net.corda.core.DeleteForDJVM import net.corda.core.DeleteForDJVM
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationFactory import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.SerializeAsTokenContext
val serializationContextKey = SerializeAsTokenContext::class.java val serializationContextKey = SerializeAsTokenContext::class.java
fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext) fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext)
fun CheckpointSerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): CheckpointSerializationContext = this.withProperty(serializationContextKey, serializationContext)
/** /**
* A context for mapping SerializationTokens to/from SerializeAsTokens. * A context for mapping SerializationTokens to/from SerializeAsTokens.
@ -55,6 +55,53 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser
} }
} }
override fun getSingleton(className: String) = classNameToSingleton[className]
?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
}
/**
* A context for mapping SerializationTokens to/from SerializeAsTokens.
*
* A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens.
* In our case this can be the [ServiceHub].
*
* Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary
* when serializing to enable/disable tokenization.
*/
@DeleteForDJVM
class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext {
constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializationFactory.serialize(toBeTokenized, context.withTokenContext(this))
})
private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>()
private var readOnly = false
init {
/**
* Go ahead and eagerly serialize the object to register all of the tokens in the context.
*
* This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which
* are encountered in the object graph as they are serialized and will therefore register the token to
* object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or
* accidental registrations from occuring as these could not be deserialized in a deserialization-first
* scenario if they are not part of this iniital context construction serialization.
*/
init(this)
readOnly = true
}
override fun putSingleton(toBeTokenized: SerializeAsToken) {
val className = toBeTokenized.javaClass.name
if (className !in classNameToSingleton) {
// Only allowable if we are in SerializeAsTokenContext init (readOnly == false)
if (readOnly) {
throw UnsupportedOperationException("Attempt to write token for lazy registered $className. All tokens should be registered during context construction.")
}
classNameToSingleton[className] = toBeTokenized
}
}
override fun getSingleton(className: String) = classNameToSingleton[className] override fun getSingleton(className: String) = classNameToSingleton[className]
?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this") ?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
} }

View File

@ -13,3 +13,11 @@ fun checkUseCase(allowedUseCases: EnumSet<SerializationContext.UseCase>) {
throw IllegalStateException("UseCase '${currentContext.useCase}' is not within '$allowedUseCases'") throw IllegalStateException("UseCase '${currentContext.useCase}' is not within '$allowedUseCases'")
} }
} }
fun checkUseCase(allowedUseCase: SerializationContext.UseCase) {
val currentContext: SerializationContext = SerializationFactory.currentFactory?.currentContext
?: throw IllegalStateException("Current context is not set")
if (allowedUseCase != currentContext.useCase) {
throw IllegalStateException("UseCase '${currentContext.useCase}' is not '$allowedUseCase'")
}
}

View File

@ -163,8 +163,6 @@ abstract class AbstractAMQPSerializationScheme(
return synchronized(serializerFactoriesForContexts) { return synchronized(serializerFactoriesForContexts) {
serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) { when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
throw IllegalStateException("AMQP should not be used for checkpoint serialization.")
SerializationContext.UseCase.RPCClient -> SerializationContext.UseCase.RPCClient ->
rpcClientSerializerFactory(context) rpcClientSerializerFactory(context)
SerializationContext.UseCase.RPCServer -> SerializationContext.UseCase.RPCServer ->

View File

@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp.custom
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.serialization.internal.amqp.* import net.corda.serialization.internal.amqp.*
import net.corda.serialization.internal.checkUseCase import net.corda.serialization.internal.checkUseCase
@ -13,14 +12,12 @@ import java.util.*
object PrivateKeySerializer : CustomSerializer.Implements<PrivateKey>(PrivateKey::class.java) { object PrivateKeySerializer : CustomSerializer.Implements<PrivateKey>(PrivateKey::class.java) {
private val allowedUseCases = EnumSet.of(Storage, Checkpoint)
override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList()))) override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList())))
override fun writeDescribedObject(obj: PrivateKey, data: Data, type: Type, output: SerializationOutput, override fun writeDescribedObject(obj: PrivateKey, data: Data, type: Type, output: SerializationOutput,
context: SerializationContext context: SerializationContext
) { ) {
checkUseCase(allowedUseCases) checkUseCase(Storage)
output.writeObject(obj.encoded, data, clazz, context) output.writeObject(obj.encoded, data, clazz, context)
} }

View File

@ -4,7 +4,6 @@ import com.google.common.collect.Maps;
import net.corda.core.serialization.SerializationContext; import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationFactory; import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes; import net.corda.core.serialization.SerializedBytes;
import net.corda.serialization.internal.amqp.AMQPNotSerializableException;
import net.corda.serialization.internal.amqp.SchemaKt; import net.corda.serialization.internal.amqp.SchemaKt;
import net.corda.testing.core.SerializationEnvironmentRule; import net.corda.testing.core.SerializationEnvironmentRule;
import org.junit.Before; import org.junit.Before;
@ -20,8 +19,10 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable; import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class ForbiddenLambdaSerializationTests { public final class ForbiddenLambdaSerializationTests {
private EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf( private EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(
EnumSet.of(SerializationContext.UseCase.Checkpoint, SerializationContext.UseCase.Testing)); EnumSet.of(SerializationContext.UseCase.Testing));
@Rule @Rule
public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule(); public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule();
private SerializationFactory factory; private SerializationFactory factory;

View File

@ -1,11 +1,11 @@
package net.corda.serialization.internal; package net.corda.serialization.internal;
import net.corda.core.serialization.SerializationContext; import net.corda.core.serialization.*;
import net.corda.core.serialization.SerializationFactory; import net.corda.core.serialization.internal.CheckpointSerializationContext;
import net.corda.core.serialization.SerializedBytes; import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.node.serialization.kryo.CordaClosureSerializer; import net.corda.node.serialization.kryo.CordaClosureSerializer;
import net.corda.node.serialization.kryo.KryoSerializationSchemeKt;
import net.corda.testing.core.SerializationEnvironmentRule; import net.corda.testing.core.SerializationEnvironmentRule;
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
@ -18,21 +18,22 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable; import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class LambdaCheckpointSerializationTest { public final class LambdaCheckpointSerializationTest {
@Rule @Rule
public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule(); public final CheckpointSerializationEnvironmentRule testCheckpointSerialization =
private SerializationFactory factory; new CheckpointSerializationEnvironmentRule();
private SerializationContext context;
private CheckpointSerializationFactory factory;
private CheckpointSerializationContext context;
@Before @Before
public void setup() { public void setup() {
factory = testSerialization.getSerializationFactory(); factory = testCheckpointSerialization.getCheckpointSerializationFactory();
context = new SerializationContextImpl( context = new CheckpointSerializationContextImpl(
KryoSerializationSchemeKt.getKryoMagic(),
getClass().getClassLoader(), getClass().getClassLoader(),
AllWhitelist.INSTANCE, AllWhitelist.INSTANCE,
Collections.emptyMap(), Collections.emptyMap(),
true, true,
SerializationContext.UseCase.Checkpoint,
null null
); );
} }

View File

@ -3,8 +3,13 @@ package net.corda.serialization.internal
import net.corda.core.contracts.ContractAttachment import net.corda.core.contracts.ContractAttachment
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
@ -17,28 +22,29 @@ import org.junit.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
class ContractAttachmentSerializerTest { class ContractAttachmentSerializerTest {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: SerializationFactory private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: SerializationContext private lateinit var context: CheckpointSerializationContext
private lateinit var contextWithToken: SerializationContext private lateinit var contextWithToken: CheckpointSerializationContext
private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock()) private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock())
@Before @Before
fun setup() { fun setup() {
factory = testSerialization.serializationFactory factory = testCheckpointSerialization.checkpointSerializationFactory
context = testSerialization.checkpointContext context = testCheckpointSerialization.checkpointSerializationContext
contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices)) contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices))
} }
@Test @Test
fun `write contract attachment and read it back`() { fun `write contract attachment and read it back`() {
val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID)
// no token context so will serialize the whole attachment // no token context so will serialize the whole attachment
val serialized = contractAttachment.serialize(factory, context) val serialized = contractAttachment.checkpointSerialize(factory, context)
val deserialized = serialized.deserialize(factory, context) val deserialized = serialized.checkpointDeserialize(factory, context)
assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract) assertEquals(contractAttachment.contract, deserialized.contract)
@ -53,8 +59,8 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null) mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.deserialize(factory, contextWithToken) val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
assertEquals(contractAttachment.id, deserialized.attachment.id) assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract) assertEquals(contractAttachment.contract, deserialized.contract)
@ -70,7 +76,7 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null) mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
assertThat(serialized.size).isLessThan(largeAttachmentSize) assertThat(serialized.size).isLessThan(largeAttachmentSize)
} }
@ -82,8 +88,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService // don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.deserialize(factory, contextWithToken) val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java) assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java)
} }
@ -94,8 +100,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService // don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID) val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken) val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
serialized.deserialize(factory, contextWithToken) serialized.checkpointDeserialize(factory, contextWithToken)
// MissingAttachmentsException thrown if we try to open attachment // MissingAttachmentsException thrown if we try to open attachment
} }

View File

@ -11,12 +11,11 @@ import com.nhaarman.mockito_kotlin.verify
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER
import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.node.serialization.kryo.CordaClassResolver import net.corda.node.serialization.kryo.CordaClassResolver
import net.corda.node.serialization.kryo.CordaKryo import net.corda.node.serialization.kryo.CordaKryo
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.services.MockAttachmentStorage import net.corda.testing.services.MockAttachmentStorage
import org.junit.Rule import org.junit.Rule
@ -115,8 +114,8 @@ class CordaClassResolverTests {
val emptyMapClass = mapOf<Any, Any>().javaClass val emptyMapClass = mapOf<Any, Any>().javaClass
} }
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null) private val emptyWhitelistContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, null)
private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null) private val allButBlacklistedContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, null)
@Test @Test
fun `Annotation on enum works for specialised entries`() { fun `Annotation on enum works for specialised entries`() {
CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java) CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java)

View File

@ -3,6 +3,8 @@ package net.corda.serialization.internal
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.serialization.SerializationContext.UseCase.* import net.corda.core.serialization.SerializationContext.UseCase.*
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
@ -33,13 +35,13 @@ class PrivateKeySerializationTest(private val privateKey: PrivateKey, private va
@Test @Test
fun `passed with expected UseCases`() { fun `passed with expected UseCases`() {
assertTrue { privateKey.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes.isNotEmpty() } assertTrue { privateKey.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes.isNotEmpty() }
assertTrue { privateKey.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() } assertTrue { privateKey.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() }
} }
@Test @Test
fun `failed with wrong UseCase`() { fun `failed with wrong UseCase`() {
assertThatThrownBy { privateKey.serialize(context = SerializationDefaults.P2P_CONTEXT) } assertThatThrownBy { privateKey.serialize(context = SerializationDefaults.P2P_CONTEXT) }
.isInstanceOf(IllegalStateException::class.java) .isInstanceOf(IllegalStateException::class.java)
.hasMessageContaining("UseCase '$P2P' is not within") .hasMessageContaining("UseCase '$P2P' is not 'Storage")
} }
} }

View File

@ -4,6 +4,10 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.node.serialization.kryo.CordaClassResolver import net.corda.node.serialization.kryo.CordaClassResolver
import net.corda.node.serialization.kryo.CordaKryo import net.corda.node.serialization.kryo.CordaKryo
@ -11,6 +15,7 @@ import net.corda.node.serialization.kryo.DefaultKryoCustomizer
import net.corda.node.serialization.kryo.kryoMagic import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
@ -18,16 +23,18 @@ import org.junit.Test
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
class SerializationTokenTest { class SerializationTokenTest {
@Rule @Rule
@JvmField @JvmField
val testSerialization = SerializationEnvironmentRule() val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before @Before
fun setup() { fun setup() {
factory = testSerialization.serializationFactory factory = testCheckpointSerialization.checkpointSerializationFactory
context = testSerialization.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java) context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java)
} }
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized
@ -42,16 +49,16 @@ class SerializationTokenTest {
override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size
} }
private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock()) private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock())
@Test @Test
fun `write token and read tokenizable`() { fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable() val tokenizableBefore = LargeTokenizable()
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext) val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext) val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore) assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
} }
@ -62,8 +69,8 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext) val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext) val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore) assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
} }
@ -72,7 +79,7 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>()) val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
tokenizableBefore.serialize(factory, testContext) tokenizableBefore.checkpointSerialize(factory, testContext)
} }
@Test(expected = UnsupportedOperationException::class) @Test(expected = UnsupportedOperationException::class)
@ -80,14 +87,14 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>()) val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).serialize(factory, testContext) val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(factory, testContext)
serializedBytes.deserialize(factory, testContext) serializedBytes.checkpointDeserialize(factory, testContext)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `no context set`() { fun `no context set`() {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
tokenizableBefore.serialize(factory, context) tokenizableBefore.checkpointSerialize(factory, context)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
@ -105,7 +112,7 @@ class SerializationTokenTest {
kryo.writeObject(it, emptyList<Any>()) kryo.writeObject(it, emptyList<Any>())
} }
val serializedBytes = SerializedBytes<Any>(stream.toByteArray()) val serializedBytes = SerializedBytes<Any>(stream.toByteArray())
serializedBytes.deserialize(factory, testContext) serializedBytes.checkpointDeserialize(factory, testContext)
} }
private class WrongTypeSerializeAsToken : SerializeAsToken { private class WrongTypeSerializeAsToken : SerializeAsToken {
@ -121,7 +128,7 @@ class SerializationTokenTest {
val tokenizableBefore = WrongTypeSerializeAsToken() val tokenizableBefore = WrongTypeSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext) val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
serializedBytes.deserialize(factory, testContext) serializedBytes.checkpointDeserialize(factory, testContext)
} }
} }

View File

@ -5,6 +5,7 @@ import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.internal.staticField import net.corda.core.internal.staticField
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.SerializationEnvironment import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv import net.corda.testing.common.internal.asContextEnv
@ -45,7 +46,6 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
private lateinit var env: SerializationEnvironment private lateinit var env: SerializationEnvironment
val serializationFactory get() = env.serializationFactory val serializationFactory get() = env.serializationFactory
val checkpointContext get() = env.checkpointContext
override fun apply(base: Statement, description: Description): Statement { override fun apply(base: Statement, description: Description): Statement {
init(description.toString()) init(description.toString())

View File

@ -0,0 +1,71 @@
package net.corda.testing.core.internal
import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.staticField
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.createTestSerializationEnv
import net.corda.testing.internal.inVMExecutors
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
/**
* A test checkpoint serialization rule implementation for use in tests.
*
* @param inheritable whether new threads inherit the environment, use sparingly.
*/
class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
companion object {
init {
// Can't turn it off, and it creates threads that do serialization, so hack it:
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
doAnswer {
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
}.execute(it.arguments[0] as Runnable)
}.whenever(it).execute(any())
}
}
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task)
}
}
private lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())
return object : Statement() {
override fun evaluate() = runTask { base.evaluate() }
}
}
private fun init(envLabel: String) {
env = createTestSerializationEnv(envLabel)
}
private fun <T> runTask(task: (SerializationEnvironment) -> T): T {
try {
return env.asContextEnv(inheritable, task)
} finally {
inVMExecutors.remove(env)
}
}
val checkpointSerializationFactory get() = env.checkpointSerializationFactory
val checkpointSerializationContext get() = env.checkpointContext
}

View File

@ -4,10 +4,11 @@ import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.* import net.corda.core.serialization.internal.*
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.serialization.internal.* import net.corda.serialization.internal.*
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
@ -33,8 +34,6 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
val factory = SerializationFactoryImpl().apply { val factory = SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList())) registerScheme(AMQPClientSerializationScheme(emptyList()))
registerScheme(AMQPServerSerializationScheme(emptyList())) registerScheme(AMQPServerSerializationScheme(emptyList()))
// needed for checkpointing
registerScheme(KryoServerSerializationScheme())
} }
return object : SerializationEnvironmentImpl( return object : SerializationEnvironmentImpl(
factory, factory,
@ -42,7 +41,8 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
AMQP_RPC_SERVER_CONTEXT, AMQP_RPC_SERVER_CONTEXT,
AMQP_RPC_CLIENT_CONTEXT, AMQP_RPC_CLIENT_CONTEXT,
AMQP_STORAGE_CONTEXT, AMQP_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT KRYO_CHECKPOINT_CONTEXT,
CheckpointSerializationFactory(KryoSerializationScheme)
) { ) {
override fun toString() = "testSerializationEnv($label)" override fun toString() = "testSerializationEnv($label)"
} }

View File

@ -17,10 +17,7 @@ import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.base64ToByteArray import net.corda.core.utilities.base64ToByteArray
import net.corda.core.utilities.hexToByteArray import net.corda.core.utilities.hexToByteArray
import net.corda.core.utilities.sequence import net.corda.core.utilities.sequence
import net.corda.serialization.internal.AMQP_P2P_CONTEXT import net.corda.serialization.internal.*
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.CordaSerializationMagic
import net.corda.serialization.internal.SerializationFactoryImpl
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
import net.corda.serialization.internal.amqp.DeserializationInput import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.amqpMagic import net.corda.serialization.internal.amqp.amqpMagic

View File

@ -3,6 +3,7 @@ package net.corda.bootstrapper.serialization
import net.corda.core.serialization.internal.SerializationEnvironmentImpl import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.serialization.internal.AMQP_P2P_CONTEXT import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.SerializationFactoryImpl import net.corda.serialization.internal.SerializationFactoryImpl
@ -20,7 +21,7 @@ class SerializationEngine {
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader), storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = AMQP_P2P_CONTEXT.withClassLoader(classloader) checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
) )
} }
} }