Introduce current context concept for serialization in preparation for WireTransaction changes (#1448)

This commit is contained in:
Rick Parker 2017-09-08 08:16:38 +01:00 committed by GitHub
parent 6bf2871819
commit 79f1e1ae7f
7 changed files with 138 additions and 34 deletions

View File

@ -3,8 +3,6 @@ package net.corda.core.serialization
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import net.corda.core.internal.WriteOnceProperty
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.sequence
@ -13,7 +11,7 @@ import net.corda.core.utilities.sequence
* An abstraction for serializing and deserializing objects, with support for versioning of the wire format via
* a header / prefix in the bytes.
*/
interface SerializationFactory {
abstract class SerializationFactory {
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
@ -21,7 +19,7 @@ interface SerializationFactory {
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T
abstract fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
@ -29,7 +27,63 @@ interface SerializationFactory {
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T>
abstract fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T>
/**
* If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization,
* this will return the current context used to start serialization/deserialization.
*/
val currentContext: SerializationContext? get() = _currentContext.get()
/**
* A context to use as a default if you do not require a specially configured context. It will be the current context
* if the use is somehow nested (see [currentContext]).
*/
val defaultContext: SerializationContext get() = currentContext ?: SerializationDefaults.P2P_CONTEXT
private val _currentContext = ThreadLocal<SerializationContext?>()
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: SerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}
/**
* Allow subclasses to temporarily mark themselves as the current factory for the current thread during serialization/deserialization.
* Will restore the prior context on exiting the block.
*/
protected fun <T> asCurrent(block: SerializationFactory.() -> T): T {
val priorContext = _currentFactory.get()
_currentFactory.set(this)
try {
return block()
} finally {
_currentFactory.set(priorContext)
}
}
companion object {
private val _currentFactory = ThreadLocal<SerializationFactory?>()
/**
* A default factory for serialization/deserialization, taking into account the [currentFactory] if set.
*/
val defaultFactory: SerializationFactory get() = currentFactory ?: SerializationDefaults.SERIALIZATION_FACTORY
/**
* If there is a need to nest serialization/deserialization with a modified context during serialization or deserialization,
* this will return the current factory used to start serialization/deserialization.
*/
val currentFactory: SerializationFactory? get() = _currentFactory.get()
}
}
/**
@ -76,6 +130,12 @@ interface SerializationContext {
*/
fun withClassLoader(classLoader: ClassLoader): SerializationContext
/**
* Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers.
* (Requires the attachment storage to have been enabled).
*/
fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext
/**
* Helper method to return a new context based on this context with the given class specifically whitelisted.
*/
@ -107,26 +167,26 @@ object SerializationDefaults {
/**
* Convenience extension method for deserializing a ByteSequence, utilising the defaults.
*/
inline fun <reified T : Any> ByteSequence.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T {
inline fun <reified T : Any> ByteSequence.deserialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the defaults.
*/
inline fun <reified T : Any> SerializedBytes<T>.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T {
inline fun <reified T : Any> SerializedBytes<T>.deserialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing a ByteArray, utilising the defaults.
*/
inline fun <reified T : Any> ByteArray.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T = this.sequence().deserialize(serializationFactory, context)
inline fun <reified T : Any> ByteArray.deserialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): T = this.sequence().deserialize(serializationFactory, context)
/**
* Convenience extension method for serializing an object of type T, utilising the defaults.
*/
fun <T : Any> T.serialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): SerializedBytes<T> {
fun <T : Any> T.serialize(serializationFactory: SerializationFactory = SerializationFactory.defaultFactory, context: SerializationContext = serializationFactory.defaultContext): SerializedBytes<T> {
return serializationFactory.serialize(this, context)
}
@ -142,4 +202,4 @@ class SerializedBytes<T : Any>(bytes: ByteArray) : OpaqueBytes(bytes) {
interface ClassWhitelist {
fun hasListed(type: Class<*>): Boolean
}
}

View File

@ -4,7 +4,7 @@ import net.corda.core.contracts.*
import net.corda.core.crypto.*
import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.serialize
import java.nio.ByteBuffer
import java.util.function.Predicate
@ -22,12 +22,12 @@ fun <T : Any> serializedHash(x: T, privacySalt: PrivacySalt?, index: Int): Secur
fun <T : Any> serializedHash(x: T, nonce: SecureHash): SecureHash {
return if (x !is PrivacySalt) // PrivacySalt is not required to have an accompanied nonce.
(x.serialize(context = P2P_CONTEXT.withoutReferences()).bytes + nonce.bytes).sha256()
(x.serialize(context = SerializationFactory.defaultFactory.defaultContext.withoutReferences()).bytes + nonce.bytes).sha256()
else
serializedHash(x)
}
fun <T : Any> serializedHash(x: T): SecureHash = x.serialize(context = P2P_CONTEXT.withoutReferences()).bytes.sha256()
fun <T : Any> serializedHash(x: T): SecureHash = x.serialize(context = SerializationFactory.defaultFactory.defaultContext.withoutReferences()).bytes.sha256()
/** The nonce is computed as Hash(privacySalt || index). */
fun computeNonce(privacySalt: PrivacySalt, index: Int) = (privacySalt.bytes + ByteBuffer.allocate(4).putInt(index).array()).sha256()

View File

@ -8,12 +8,11 @@ import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.MapReferenceResolver
import net.corda.core.contracts.*
import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.CompositeKey
import net.corda.core.identity.Party
import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.AttachmentsClassLoader
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializeAsTokenContext
@ -241,9 +240,6 @@ fun Input.readBytesWithLength(): ByteArray {
/** A serialisation engine that knows how to deserialise code inside a sandbox */
@ThreadSafe
object WireTransactionSerializer : Serializer<WireTransaction>() {
@VisibleForTesting
internal val attachmentsClassLoaderEnabled = "attachments.class.loader.enabled"
override fun write(kryo: Kryo, output: Output, obj: WireTransaction) {
kryo.writeClassAndObject(output, obj.inputs)
kryo.writeClassAndObject(output, obj.attachments)
@ -255,7 +251,7 @@ object WireTransactionSerializer : Serializer<WireTransaction>() {
}
private fun attachmentsClassLoader(kryo: Kryo, attachmentHashes: List<SecureHash>): ClassLoader? {
kryo.context[attachmentsClassLoaderEnabled] as? Boolean ?: false || return null
kryo.context[attachmentsClassLoaderEnabledPropertyName] as? Boolean ?: false || return null
val serializationContext = kryo.serializationContext() ?: return null // Some tests don't set one.
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()

View File

@ -8,6 +8,10 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import net.corda.core.contracts.Attachment
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.LazyPool
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
@ -17,6 +21,8 @@ import java.io.NotSerializableException
import java.util.*
import java.util.concurrent.ConcurrentHashMap
val attachmentsClassLoaderEnabledPropertyName = "attachments.class.loader.enabled"
object NotSupportedSeralizationScheme : SerializationScheme {
private fun doThrow(): Nothing = throw UnsupportedOperationException("Serialization scheme not supported.")
@ -33,6 +39,24 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase) : SerializationContext {
private val cache: Cache<List<SecureHash>, AttachmentsClassLoader> = CacheBuilder.newBuilder().weakValues().maximumSize(1024).build()
// We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext {
properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean ?: false || return this
val serializationContext = properties[serializationContextKey] as? SerializeAsTokenContextImpl ?: return this // Some tests don't set one.
return withClassLoader(cache.get(attachmentHashes) {
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
attachmentHashes.forEach { id ->
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
AttachmentsClassLoader(attachments)
})
}
override fun withProperty(property: Any, value: Any): SerializationContext {
return copy(properties = properties + (property to value))
}
@ -56,7 +80,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
private const val HEADER_SIZE: Int = 8
open class SerializationFactoryImpl : SerializationFactory {
open class SerializationFactoryImpl : SerializationFactory() {
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
private val registeredSchemes: MutableCollection<SerializationScheme> = Collections.synchronizedCollection(mutableListOf())
@ -75,10 +99,12 @@ open class SerializationFactoryImpl : SerializationFactory {
}
@Throws(NotSerializableException::class)
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T = schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context)
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
return asCurrent { withCurrentContext(context) { schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context) } }
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
return schemeFor(context.preferredSerializationVersion, context.useCase).serialize(obj, context)
return asCurrent { withCurrentContext(context) { schemeFor(context.preferredSerializationVersion, context.useCase).serialize(obj, context) } }
}
fun registerScheme(scheme: SerializationScheme) {

View File

@ -10,13 +10,13 @@ import net.corda.core.internal.declaredField
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.serialization.SerializationFactory
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
import net.corda.nodeapi.internal.serialization.WireTransactionSerializer
import net.corda.nodeapi.internal.serialization.attachmentsClassLoaderEnabledPropertyName
import net.corda.nodeapi.internal.serialization.withTokenContext
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.MEGA_CORP
@ -51,7 +51,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
private fun SerializationContext.withAttachmentStorage(attachmentStorage: AttachmentStorage): SerializationContext {
val serviceHub = mock<ServiceHub>()
whenever(serviceHub.attachments).thenReturn(attachmentStorage)
return this.withTokenContext(SerializeAsTokenContextImpl(serviceHub) {}).withProperty(WireTransactionSerializer.attachmentsClassLoaderEnabled, true)
return this.withTokenContext(SerializeAsTokenContextImpl(serviceHub) {}).withProperty(attachmentsClassLoaderEnabledPropertyName, true)
}
}
@ -223,7 +223,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(contract.javaClass)
val context = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl).withWhitelisted(contract.javaClass)
val state2 = bytes.deserialize(context = context)
assertTrue(state2.javaClass.classLoader is AttachmentsClassLoader)
assertNotNull(state2)
@ -239,7 +239,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
assertNotNull(data.contract)
val context2 = P2P_CONTEXT.withWhitelisted(data.contract.javaClass)
val context2 = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(data.contract.javaClass)
val bytes = data.serialize(context = context2)
@ -251,7 +251,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl))
val context = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl))
val state2 = bytes.deserialize(context = context)
assertEquals(cl, state2.contract.javaClass.classLoader)
@ -260,7 +260,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
// We should be able to load same class from a different class loader and have them be distinct.
val cl2 = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context3 = P2P_CONTEXT.withClassLoader(cl2).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2))
val context3 = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl2).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2))
val state3 = bytes.deserialize(context = context3)
assertEquals(cl2, state3.contract.javaClass.classLoader)
@ -312,7 +312,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage()
val context = P2P_CONTEXT.withWhitelisted(contract.javaClass)
val context = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(contract.javaClass)
.withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child))
.withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child))
.withAttachmentStorage(storage)
@ -346,13 +346,13 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val wireTransaction = tx.toWireTransaction()
wireTransaction.serialize(context = P2P_CONTEXT.withAttachmentStorage(storage))
wireTransaction.serialize(context = SerializationFactory.defaultFactory.defaultContext.withAttachmentStorage(storage))
}
// use empty attachmentStorage
val e = assertFailsWith(MissingAttachmentsException::class) {
val mockAttStorage = MockAttachmentStorage()
bytes.deserialize(context = P2P_CONTEXT.withAttachmentStorage(mockAttStorage))
bytes.deserialize(context = SerializationFactory.defaultFactory.defaultContext.withAttachmentStorage(mockAttStorage))
if(mockAttStorage.openAttachment(attachmentRef) == null) {
throw MissingAttachmentsException(listOf(attachmentRef))
@ -360,4 +360,21 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
}
assertEquals(attachmentRef, e.ids.single())
}
@Test
fun `test loading a class from attachment during deserialization`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val storage = MockAttachmentStorage()
val attachmentRef = importJar(storage)
val outboundContext = SerializationFactory.defaultFactory.defaultContext.withClassLoader(child)
// We currently ignore annotations in attachments, so manually whitelist.
val inboundContext = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(contract.javaClass).withAttachmentStorage(storage).withAttachmentsClassLoader(listOf(attachmentRef))
// Serialize with custom context to avoid populating the default context with the specially loaded class
val serialized = contract.serialize(context = outboundContext)
// Then deserialize with the attachment class loader associated with the attachment
serialized.deserialize(context = inboundContext)
}
}

View File

@ -76,7 +76,7 @@ class DefaultSerializableSerializer : Serializer<DefaultSerializable>() {
}
class CordaClassResolverTests {
val factory: SerializationFactory = object : SerializationFactory {
val factory: SerializationFactory = object : SerializationFactory() {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}

View File

@ -1,6 +1,7 @@
package net.corda.testing
import net.corda.client.rpc.serialization.KryoClientSerializationScheme
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.node.serialization.KryoServerSerializationScheme
@ -89,7 +90,7 @@ fun resetTestSerialization() {
(SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate = null
}
class TestSerializationFactory : SerializationFactory {
class TestSerializationFactory : SerializationFactory() {
var delegate: SerializationFactory? = null
set(value) {
field = value
@ -150,4 +151,8 @@ class TestSerializationContext : SerializationContext {
override fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) }
}
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withAttachmentsClassLoader(attachmentHashes) }
}
}