Merge community-master

This commit is contained in:
Michal Kit
2017-08-15 12:04:09 +01:00
649 changed files with 21249 additions and 14346 deletions

View File

@ -6,7 +6,7 @@ import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.ServiceType
import net.corda.core.read
import net.corda.core.internal.read
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.NetworkHostAndPort

View File

@ -1,7 +1,6 @@
package net.corda.nodeapi
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.KryoPoolWithContext
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.Try
@ -96,12 +95,12 @@ object RPCApi {
val methodName: String,
val arguments: List<Any?>
) : ClientToServer() {
fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) {
fun writeToClientMessage(context: SerializationContext, message: ClientMessage) {
MessageUtil.setJMSReplyTo(message, clientAddress)
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal)
message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong)
message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName)
message.bodyBuffer.writeBytes(arguments.serialize(kryoPool).bytes)
message.bodyBuffer.writeBytes(arguments.serialize(context = context).bytes)
}
}
@ -119,14 +118,14 @@ object RPCApi {
}
companion object {
fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ClientToServer {
fun fromClientMessage(context: SerializationContext, message: ClientMessage): ClientToServer {
val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)]
return when (tag) {
RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest(
clientAddress = MessageUtil.getJMSReplyTo(message),
id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)),
methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME),
arguments = message.getBodyAsByteArray().deserialize(kryoPool)
arguments = message.getBodyAsByteArray().deserialize(context = context)
)
RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> {
val ids = ArrayList<ObservableId>()
@ -148,48 +147,48 @@ object RPCApi {
OBSERVATION
}
abstract fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage)
abstract fun writeToClientMessage(context: SerializationContext, message: ClientMessage)
data class RpcReply(
val id: RpcRequestId,
val result: Try<Any?>
) : ServerToClient() {
override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) {
override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal)
message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong)
message.bodyBuffer.writeBytes(result.serialize(kryoPool).bytes)
message.bodyBuffer.writeBytes(result.serialize(context = context).bytes)
}
}
data class Observation(
val id: ObservableId,
val content: Notification<Any>
val content: Notification<*>
) : ServerToClient() {
override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) {
override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal)
message.putLongProperty(OBSERVABLE_ID_FIELD_NAME, id.toLong)
message.bodyBuffer.writeBytes(content.serialize(kryoPool).bytes)
message.bodyBuffer.writeBytes(content.serialize(context = context).bytes)
}
}
companion object {
fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ServerToClient {
fun fromClientMessage(context: SerializationContext, message: ClientMessage): ServerToClient {
val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)]
return when (tag) {
RPCApi.ServerToClient.Tag.RPC_REPLY -> {
val id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME))
val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong)
val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id.toLong)
RpcReply(
id = id,
result = message.getBodyAsByteArray().deserialize(poolWithIdContext)
result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext)
)
}
RPCApi.ServerToClient.Tag.OBSERVATION -> {
val id = ObservableId(message.getLongProperty(OBSERVABLE_ID_FIELD_NAME))
val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong)
val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id.toLong)
Observation(
id = id,
content = message.getBodyAsByteArray().deserialize(poolWithIdContext)
content = message.getBodyAsByteArray().deserialize(context = poolWithIdContext)
)
}
}

View File

@ -4,13 +4,16 @@ package net.corda.nodeapi
import com.esotericsoftware.kryo.Registration
import com.esotericsoftware.kryo.Serializer
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.requireExternal
import net.corda.core.serialization.*
import net.corda.core.concurrent.CordaFuture
import net.corda.core.CordaRuntimeException
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.toFuture
import net.corda.core.toObservable
import net.corda.core.CordaRuntimeException
import net.corda.nodeapi.config.OldConfig
import net.corda.nodeapi.internal.serialization.*
import rx.Observable
import java.io.InputStream
@ -46,16 +49,15 @@ class PermissionException(msg: String) : RuntimeException(msg)
// The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly.
// This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes,
// because we can see everything we're using in one place.
class RPCKryo(observableSerializer: Serializer<Observable<Any>>) : CordaKryo(makeStandardClassResolver()) {
class RPCKryo(observableSerializer: Serializer<Observable<*>>, val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : CordaKryo(CordaClassResolver(serializationFactory, serializationContext)) {
init {
DefaultKryoCustomizer.customize(this)
// RPC specific classes
register(InputStream::class.java, InputStreamSerializer)
register(Observable::class.java, observableSerializer)
@Suppress("UNCHECKED_CAST")
register(ListenableFuture::class,
read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java as Class<Observable<Any>>).toFuture() },
register(CordaFuture::class,
read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java).toFuture() },
write = { kryo, output, obj -> observableSerializer.write(kryo, output, obj.toObservable()) }
)
}
@ -67,10 +69,14 @@ class RPCKryo(observableSerializer: Serializer<Observable<Any>>) : CordaKryo(mak
if (InputStream::class.java != type && InputStream::class.java.isAssignableFrom(type)) {
return super.getRegistration(InputStream::class.java)
}
if (ListenableFuture::class.java != type && ListenableFuture::class.java.isAssignableFrom(type)) {
return super.getRegistration(ListenableFuture::class.java)
if (CordaFuture::class.java != type && CordaFuture::class.java.isAssignableFrom(type)) {
return super.getRegistration(CordaFuture::class.java)
}
type.requireExternal("RPC not allowed to deserialise internal classes")
return super.getRegistration(type)
}
private fun Class<*>.requireExternal(msg: String) {
require(!name.startsWith("net.corda.node.") && !name.contains(".internal.")) { "$msg: $name" }
}
}

View File

@ -2,7 +2,7 @@ package net.corda.nodeapi.config
import com.typesafe.config.Config
import com.typesafe.config.ConfigUtil
import net.corda.core.noneOrSingle
import net.corda.core.internal.noneOrSingle
import net.corda.core.utilities.validateX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.parseNetworkHostAndPort
@ -73,7 +73,7 @@ private fun Config.getSingleValue(path: String, type: KType): Any? {
Path::class -> Paths.get(getString(path))
URL::class -> URL(getString(path))
Properties::class -> getConfig(path).toProperties()
X500Name::class -> X500Name(getString(path)).apply(::validateX500Name)
X500Name::class -> X500Name(getString(path))
else -> if (typeClass.java.isEnum) {
parseEnum(typeClass.java, getString(path))
} else {

View File

@ -1,6 +1,6 @@
package net.corda.nodeapi.config
import net.corda.core.div
import net.corda.core.internal.div
import java.nio.file.Path
interface SSLConfiguration {

View File

@ -0,0 +1,107 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import java.util.concurrent.ConcurrentHashMap
internal val AMQP_ENABLED get() = SerializationDefaults.P2P_CONTEXT.preferedSerializationVersion == AmqpHeaderV1_0
abstract class AbstractAMQPSerializationScheme : SerializationScheme {
internal companion object {
fun registerCustomSerializers(factory: SerializerFactory) {
factory.apply {
register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this))
}
}
}
private val serializerFactoriesForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, SerializerFactory>()
protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory
protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory
private fun getSerializerFactory(context: SerializationContext): SerializerFactory {
return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
throw IllegalStateException("AMQP should not be used for checkpoint serialization.")
SerializationContext.UseCase.RPCClient ->
rpcClientSerializerFactory(context)
SerializationContext.UseCase.RPCServer ->
rpcServerSerializerFactory(context)
else -> SerializerFactory(context.whitelist) // TODO pass class loader also
}
}.also { registerCustomSerializers(it) }
}
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
val serializerFactory = getSerializerFactory(context)
return DeserializationInput(serializerFactory).deserialize(byteSequence, clazz)
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
val serializerFactory = getSerializerFactory(context)
return SerializationOutput(serializerFactory).serialize(obj)
}
protected fun canDeserializeVersion(byteSequence: ByteSequence): Boolean = AMQP_ENABLED && byteSequence == AmqpHeaderV1_0
}
// TODO: This will eventually cover server RPC as well and move to node module, but for now this is not implemented
class AMQPServerSerializationScheme : AbstractAMQPSerializationScheme() {
override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory {
throw UnsupportedOperationException()
}
override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
return (canDeserializeVersion(byteSequence) &&
(target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage))
}
}
// TODO: This will eventually cover client RPC as well and move to client module, but for now this is not implemented
class AMQPClientSerializationScheme : AbstractAMQPSerializationScheme() {
override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory {
throw UnsupportedOperationException()
}
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
return (canDeserializeVersion(byteSequence) &&
(target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage))
}
}
val AMQP_P2P_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.P2P)
val AMQP_STORAGE_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0,
SerializationDefaults.javaClass.classLoader,
AllButBlacklisted,
emptyMap(),
true,
SerializationContext.UseCase.Storage)

View File

@ -0,0 +1,135 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.serialization.ClassWhitelist
import sun.misc.Unsafe
import sun.security.util.Password
import java.io.*
import java.lang.invoke.*
import java.lang.reflect.AccessibleObject
import java.lang.reflect.Modifier
import java.lang.reflect.Parameter
import java.lang.reflect.ReflectPermission
import java.net.DatagramSocket
import java.net.ServerSocket
import java.net.Socket
import java.net.URLConnection
import java.security.AccessController
import java.security.KeyStore
import java.security.Permission
import java.security.Provider
import java.sql.Connection
import java.util.*
import java.util.logging.Handler
import java.util.zip.ZipFile
import kotlin.collections.HashSet
import kotlin.collections.LinkedHashSet
/**
* This is a [ClassWhitelist] implementation where everything is whitelisted except for blacklisted classes and interfaces.
* In practice, as flows are arbitrary code in which it is convenient to do many things,
* we can often end up pulling in a lot of objects that do not make sense to put in a checkpoint.
* Thus, by blacklisting classes/interfaces we don't expect to be serialised, we can better handle/monitor the aforementioned behaviour.
* Inheritance works for blacklisted items, but one can specifically exclude classes from blacklisting as well.
*/
object AllButBlacklisted : ClassWhitelist {
private val blacklistedClasses = hashSetOf<String>(
// Known blacklisted classes.
Thread::class.java.name,
HashSet::class.java.name,
HashMap::class.java.name,
ClassLoader::class.java.name,
Handler::class.java.name, // MemoryHandler, StreamHandler
Runtime::class.java.name,
Unsafe::class.java.name,
ZipFile::class.java.name,
Provider::class.java.name,
SecurityManager::class.java.name,
Random::class.java.name,
// Known blacklisted interfaces.
Connection::class.java.name,
// TODO: AutoCloseable::class.java.name,
// java.security.
KeyStore::class.java.name,
Password::class.java.name,
AccessController::class.java.name,
Permission::class.java.name,
// java.net.
DatagramSocket::class.java.name,
ServerSocket::class.java.name,
Socket::class.java.name,
URLConnection::class.java.name,
// TODO: add more from java.net.
// java.io.
Console::class.java.name,
File::class.java.name,
FileDescriptor::class.java.name,
FilePermission::class.java.name,
RandomAccessFile::class.java.name,
Reader::class.java.name,
Writer::class.java.name,
// TODO: add more from java.io.
// java.lang.invoke classes.
CallSite::class.java.name, // for all CallSites eg MutableCallSite, VolatileCallSite etc.
LambdaMetafactory::class.java.name,
MethodHandle::class.java.name,
MethodHandleProxies::class.java.name,
MethodHandles::class.java.name,
MethodHandles.Lookup::class.java.name,
MethodType::class.java.name,
SerializedLambda::class.java.name,
SwitchPoint::class.java.name,
// java.lang.invoke interfaces.
MethodHandleInfo::class.java.name,
// java.lang.invoke exceptions.
LambdaConversionException::class.java.name,
WrongMethodTypeException::class.java.name,
// java.lang.reflect.
AccessibleObject::class.java.name, // For Executable, Field, Method, Constructor.
Modifier::class.java.name,
Parameter::class.java.name,
ReflectPermission::class.java.name
// TODO: add more from java.lang.reflect.
)
// Specifically exclude classes from the blacklist,
// even if any of their superclasses and/or implemented interfaces are blacklisted.
private val forciblyAllowedClasses = hashSetOf<String>(
LinkedHashSet::class.java.name,
LinkedHashMap::class.java.name,
InputStream::class.java.name,
BufferedInputStream::class.java.name,
Class.forName("sun.net.www.protocol.jar.JarURLConnection\$JarURLInputStream").name
)
/**
* This implementation supports inheritance; thus, if a superclass or superinterface is blacklisted, so is the input class.
*/
override fun hasListed(type: Class<*>): Boolean {
// Check if excluded.
if (type.name !in forciblyAllowedClasses) {
// Check if listed.
if (type.name in blacklistedClasses)
throw IllegalStateException("Class ${type.name} is blacklisted, so it cannot be used in serialization.")
// Inheritance check.
else {
val aMatch = blacklistedClasses.firstOrNull { Class.forName(it).isAssignableFrom(type) }
if (aMatch != null) {
// TODO: blacklistedClasses += type.name // add it, so checking is faster next time we encounter this class.
val matchType = if (Class.forName(aMatch).isInterface) "superinterface" else "superclass"
throw IllegalStateException("The $matchType $aMatch of ${type.name} is blacklisted, so it cannot be used in serialization.")
}
}
}
return true
}
}

View File

@ -0,0 +1,232 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util
import net.corda.core.serialization.*
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0
import java.io.PrintWriter
import java.lang.reflect.Modifier.isAbstract
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Paths
import java.nio.file.StandardOpenOption
import java.util.*
fun Kryo.addToWhitelist(type: Class<*>) {
((classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type)
}
/**
* @param amqpEnabled Setting this to true turns on experimental AMQP serialization for any class annotated with
* [CordaSerializable].
*/
class CordaClassResolver(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
/** Returns the registration for the specified class, or null if the class is not registered. */
override fun getRegistration(type: Class<*>): Registration? {
return super.getRegistration(type) ?: checkClass(type)
}
private var whitelistEnabled = true
fun disableWhitelist() {
whitelistEnabled = false
}
fun enableWhitelist() {
whitelistEnabled = true
}
private fun checkClass(type: Class<*>): Registration? {
// If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking.
if (!whitelistEnabled) return null
// Allow primitives, abstracts and interfaces
if (type.isPrimitive || type == Any::class.java || isAbstract(type.modifiers) || type == String::class.java) return null
// If array, recurse on element type
if (type.isArray) return checkClass(type.componentType)
// Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry.
if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass)
// Kotlin lambdas require some special treatment
if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null
// It's safe to have the Class already, since Kryo loads it with initialisation off.
// If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted.
// Thus, blacklisting precedes annotation checking.
if (!whitelist.hasListed(type) && !checkForAnnotation(type)) {
throw KryoException("Class ${Util.className(type)} is not annotated or on the whitelist, so cannot be used in serialization")
}
return null
}
override fun registerImplicit(type: Class<*>): Registration {
// If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the
// case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures
// are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph.
if (checkForAnnotation(type) && AMQP_ENABLED) {
// Build AMQP serializer
return register(Registration(type, KryoAMQPSerializer(serializationFactory, serializationContext), NAME.toInt()))
}
val objectInstance = try {
type.kotlin.objectInstance
} catch (t: Throwable) {
null // objectInstance will throw if the type is something like a lambda
}
// We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent.
val references = kryo.references
try {
kryo.references = true
val serializer = if (objectInstance != null) {
KotlinObjectSerializer(objectInstance)
} else if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) {
// Kotlin lambdas extend this class and any captured variables are stored in synthentic fields
FieldSerializer<Any>(kryo, type).apply { setIgnoreSyntheticFields(false) }
} else {
kryo.getDefaultSerializer(type)
}
return register(Registration(type, serializer, NAME.toInt()))
} finally {
kryo.references = references
}
}
// Trivial Serializer which simply returns the given instance, which we already know is a Kotlin object
private class KotlinObjectSerializer(private val objectInstance: Any) : Serializer<Any>() {
override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any = objectInstance
override fun write(kryo: Kryo, output: Output, obj: Any) = Unit
}
// We don't allow the annotation for classes in attachments for now. The class will be on the main classpath if we have the CorDapp installed.
// We also do not allow extension of KryoSerializable for annotated classes, or combination with @DefaultSerializer for custom serialisation.
// TODO: Later we can support annotations on attachment classes and spin up a proxy via bytecode that we know is harmless.
private fun checkForAnnotation(type: Class<*>): Boolean {
return (type.classLoader !is AttachmentsClassLoader)
&& !KryoSerializable::class.java.isAssignableFrom(type)
&& !type.isAnnotationPresent(DefaultSerializer::class.java)
&& (type.isAnnotationPresent(CordaSerializable::class.java) || hasInheritedAnnotation(type))
}
// Recursively check interfaces for our annotation.
private fun hasInheritedAnnotation(type: Class<*>): Boolean {
return type.interfaces.any { it.isAnnotationPresent(CordaSerializable::class.java) || hasInheritedAnnotation(it) }
|| (type.superclass != null && hasInheritedAnnotation(type.superclass))
}
// Need to clear out class names from attachments.
override fun reset() {
super.reset()
// Kryo creates a cache of class name to Class<*> which does not work so well with multiple class loaders.
// TODO: come up with a more efficient way. e.g. segregate the name space by class loader.
if (nameToClass != null) {
val classesToRemove: MutableList<String> = ArrayList(nameToClass.size)
for (entry in nameToClass.entries()) {
if (entry.value.classLoader is AttachmentsClassLoader) {
classesToRemove += entry.key
}
}
for (className in classesToRemove) {
nameToClass.remove(className)
}
}
}
}
interface MutableClassWhitelist : ClassWhitelist {
fun add(entry: Class<*>)
}
object EmptyWhitelist : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = false
}
class BuiltInExceptionsWhitelist : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = Throwable::class.java.isAssignableFrom(type) && type.`package`.name.startsWith("java.")
}
object AllWhitelist : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = true
}
// TODO: Need some concept of from which class loader
class GlobalTransientClassWhiteList(val delegate: ClassWhitelist) : MutableClassWhitelist, ClassWhitelist by delegate {
companion object {
val whitelist: MutableSet<String> = Collections.synchronizedSet(mutableSetOf())
}
override fun hasListed(type: Class<*>): Boolean {
return (type.name in whitelist) || delegate.hasListed(type)
}
override fun add(entry: Class<*>) {
whitelist += entry.name
}
}
/**
* A whitelist that can be customised via the [CordaPluginRegistry], since implements [MutableClassWhitelist].
*/
class TransientClassWhiteList(val delegate: ClassWhitelist) : MutableClassWhitelist, ClassWhitelist by delegate {
val whitelist: MutableSet<String> = Collections.synchronizedSet(mutableSetOf())
override fun hasListed(type: Class<*>): Boolean {
return (type.name in whitelist) || delegate.hasListed(type)
}
override fun add(entry: Class<*>) {
whitelist += entry.name
}
}
/**
* This class is not currently used, but can be installed to log a large number of missing entries from the whitelist
* and was used to track down the initial set.
*/
@Suppress("unused")
class LoggingWhitelist(val delegate: ClassWhitelist, val global: Boolean = true) : MutableClassWhitelist {
companion object {
val log = loggerFor<LoggingWhitelist>()
val globallySeen: MutableSet<String> = Collections.synchronizedSet(mutableSetOf())
val journalWriter: PrintWriter? = openOptionalDynamicWhitelistJournal()
private fun openOptionalDynamicWhitelistJournal(): PrintWriter? {
val fileName = System.getenv("WHITELIST_FILE")
if (fileName != null && fileName.isNotEmpty()) {
try {
return PrintWriter(Files.newBufferedWriter(Paths.get(fileName), StandardCharsets.UTF_8, StandardOpenOption.CREATE, StandardOpenOption.APPEND, StandardOpenOption.WRITE), true)
} catch(ioEx: Exception) {
log.error("Could not open/create whitelist journal file for append: $fileName", ioEx)
}
}
return null
}
}
private val locallySeen: MutableSet<String> = mutableSetOf()
private val alreadySeen: MutableSet<String> get() = if (global) globallySeen else locallySeen
override fun hasListed(type: Class<*>): Boolean {
if (type.name !in alreadySeen && !delegate.hasListed(type)) {
alreadySeen += type.name
val className = Util.className(type)
log.warn("Dynamically whitelisted class $className")
journalWriter?.println(className)
}
return true
}
override fun add(entry: Class<*>) {
if (delegate is MutableClassWhitelist) {
delegate.add(entry)
} else {
throw UnsupportedOperationException("Cannot add to whitelist since delegate whitelist is not mutable.")
}
}
}

View File

@ -0,0 +1,159 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.MapReferenceResolver
import de.javakaffee.kryoserializers.ArraysAsListSerializer
import de.javakaffee.kryoserializers.BitSetSerializer
import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer
import de.javakaffee.kryoserializers.guava.*
import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.node.CordaPluginRegistry
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.toNonEmptySet
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey
import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPrivateCrtKey
import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPublicKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PublicKey
import org.objenesis.instantiator.ObjectInstantiator
import org.objenesis.strategy.InstantiatorStrategy
import org.objenesis.strategy.StdInstantiatorStrategy
import org.slf4j.Logger
import sun.security.ec.ECPublicKeyImpl
import sun.security.provider.certpath.X509CertPath
import java.io.BufferedInputStream
import java.io.FileInputStream
import java.io.InputStream
import java.lang.reflect.Modifier.isPublic
import java.security.cert.CertPath
import java.util.*
import kotlin.collections.ArrayList
object DefaultKryoCustomizer {
private val pluginRegistries: List<CordaPluginRegistry> by lazy {
ServiceLoader.load(CordaPluginRegistry::class.java, this.javaClass.classLoader).toList()
}
fun customize(kryo: Kryo): Kryo {
return kryo.apply {
// Store a little schema of field names in the stream the first time a class is used which increases tolerance
// for change to a class.
setDefaultSerializer(CompatibleFieldSerializer::class.java)
// Take the safest route here and allow subclasses to have fields named the same as super classes.
fieldSerializerConfig.cachedFieldNameStrategy = FieldSerializer.CachedFieldNameStrategy.EXTENDED
instantiatorStrategy = CustomInstantiatorStrategy()
// WARNING: reordering the registrations here will cause a change in the serialized form, since classes
// with custom serializers get written as registration ids. This will break backwards-compatibility.
// Please add any new registrations to the end.
// TODO: re-organise registrations into logical groups before v1.0
register(Arrays.asList("").javaClass, ArraysAsListSerializer())
register(SignedTransaction::class.java, SignedTransactionSerializer)
register(WireTransaction::class.java, WireTransactionSerializer)
register(SerializedBytes::class.java, SerializedBytesSerializer)
UnmodifiableCollectionsSerializer.registerSerializers(this)
ImmutableListSerializer.registerSerializers(this)
ImmutableSetSerializer.registerSerializers(this)
ImmutableSortedSetSerializer.registerSerializers(this)
ImmutableMapSerializer.registerSerializers(this)
ImmutableMultimapSerializer.registerSerializers(this)
// InputStream subclasses whitelisting, required for attachments.
register(BufferedInputStream::class.java, InputStreamSerializer)
register(Class.forName("sun.net.www.protocol.jar.JarURLConnection\$JarURLInputStream"), InputStreamSerializer)
noReferencesWithin<WireTransaction>()
register(ECPublicKeyImpl::class.java, ECPublicKeyImplSerializer)
register(EdDSAPublicKey::class.java, Ed25519PublicKeySerializer)
register(EdDSAPrivateKey::class.java, Ed25519PrivateKeySerializer)
// Using a custom serializer for compactness
register(CompositeKey::class.java, CompositeKeySerializer)
// Exceptions. We don't bother sending the stack traces as the client will fill in its own anyway.
register(Array<StackTraceElement>::class, read = { _, _ -> emptyArray() }, write = { _, _, _ -> })
// This ensures a NonEmptySetSerializer is constructed with an initial value.
register(NonEmptySet::class.java, NonEmptySetSerializer)
addDefaultSerializer(SerializeAsToken::class.java, SerializeAsTokenSerializer<SerializeAsToken>())
register(BitSet::class.java, BitSetSerializer())
register(Class::class.java, ClassSerializer)
addDefaultSerializer(Logger::class.java, LoggerSerializer)
register(FileInputStream::class.java, InputStreamSerializer)
// Required for HashCheckingStream (de)serialization.
// Note that return type should be specifically set to InputStream, otherwise it may not work, i.e. val aStream : InputStream = HashCheckingStream(...).
addDefaultSerializer(InputStream::class.java, InputStreamSerializer)
register(CertPath::class.java, CertPathSerializer)
register(X509CertPath::class.java, CertPathSerializer)
register(X500Name::class.java, X500NameSerializer)
register(X509CertificateHolder::class.java, X509CertificateSerializer)
register(BCECPrivateKey::class.java, PrivateKeySerializer)
register(BCECPublicKey::class.java, PublicKeySerializer)
register(BCRSAPrivateCrtKey::class.java, PrivateKeySerializer)
register(BCRSAPublicKey::class.java, PublicKeySerializer)
register(BCSphincs256PrivateKey::class.java, PrivateKeySerializer)
register(BCSphincs256PublicKey::class.java, PublicKeySerializer)
register(sun.security.ec.ECPublicKeyImpl::class.java, PublicKeySerializer)
register(NotaryChangeWireTransaction::class.java, NotaryChangeWireTransactionSerializer)
val customization = KryoSerializationCustomization(this)
pluginRegistries.forEach { it.customizeSerialization(customization) }
}
}
private class CustomInstantiatorStrategy : InstantiatorStrategy {
private val fallbackStrategy = StdInstantiatorStrategy()
// Use this to allow construction of objects using a JVM backdoor that skips invoking the constructors, if there
// is no no-arg constructor available.
private val defaultStrategy = Kryo.DefaultInstantiatorStrategy(fallbackStrategy)
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
// However this doesn't work for non-public classes in the java. namespace
val strat = if (type.name.startsWith("java.") && !isPublic(type.modifiers)) fallbackStrategy else defaultStrategy
return strat.newInstantiatorOf(type)
}
}
private object NonEmptySetSerializer : Serializer<NonEmptySet<Any>>() {
override fun write(kryo: Kryo, output: Output, obj: NonEmptySet<Any>) {
// Write out the contents as normal
output.writeInt(obj.size, true)
obj.forEach { kryo.writeClassAndObject(output, it) }
}
override fun read(kryo: Kryo, input: Input, type: Class<NonEmptySet<Any>>): NonEmptySet<Any> {
val size = input.readInt(true)
require(size >= 1) { "Invalid size read off the wire: $size" }
val list = ArrayList<Any>(size)
repeat(size) {
list += kryo.readClassAndObject(input)
}
return list.toNonEmptySet()
}
}
}

View File

@ -1,4 +1,4 @@
package net.corda.nodeapi.serialization
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.KryoException
import net.corda.core.node.CordaPluginRegistry

View File

@ -0,0 +1,572 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.MapReferenceResolver
import net.corda.core.contracts.*
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.identity.Party
import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.AttachmentsClassLoader
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.SgxSupport
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec
import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.cert.X509CertificateHolder
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import sun.security.ec.ECPublicKeyImpl
import sun.security.util.DerValue
import java.io.ByteArrayInputStream
import java.io.InputStream
import java.lang.reflect.InvocationTargetException
import java.security.PrivateKey
import java.security.PublicKey
import java.security.cert.CertPath
import java.security.cert.CertificateFactory
import java.util.*
import javax.annotation.concurrent.ThreadSafe
import kotlin.reflect.KClass
import kotlin.reflect.KMutableProperty
import kotlin.reflect.KParameter
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.jvm.isAccessible
import kotlin.reflect.jvm.javaType
/**
* Serialization utilities, using the Kryo framework with a custom serialiser for immutable data classes and a dead
* simple, totally non-extensible binary (sub)format.
*
* This is NOT what should be used in any final platform product, rather, the final state should be a precisely
* specified and standardised binary format with attention paid to anti-malleability, versioning and performance.
* FIX SBE is a potential candidate: it prioritises performance over convenience and was designed for HFT. Google
* Protocol Buffers with a minor tightening to make field reordering illegal is another possibility.
*
* FIX SBE:
* https://real-logic.github.io/simple-binary-encoding/
* http://mechanical-sympathy.blogspot.co.at/2014/05/simple-binary-encoding.html
* Protocol buffers:
* https://developers.google.com/protocol-buffers/
*
* But for now we use Kryo to maximise prototyping speed.
*
* Note that this code ignores *ALL* concerns beyond convenience, in particular it ignores:
*
* - Performance
* - Security
*
* This code will happily deserialise literally anything, including malicious streams that would reconstruct classes
* in invalid states, thus violating system invariants. It isn't designed to handle malicious streams and therefore,
* isn't usable beyond the prototyping stage. But that's fine: we can revisit serialisation technologies later after
* a formal evaluation process.
*
* We now distinguish between internal, storage related Kryo and external, network facing Kryo. We presently use
* some non-whitelisted classes as part of internal storage.
* TODO: eliminate internal, storage related whitelist issues, such as private keys in blob storage.
*/
/**
* A serialiser that avoids writing the wrapper class to the byte stream, thus ensuring [SerializedBytes] is a pure
* type safety hack.
*/
object SerializedBytesSerializer : Serializer<SerializedBytes<Any>>() {
override fun write(kryo: Kryo, output: Output, obj: SerializedBytes<Any>) {
output.writeVarInt(obj.bytes.size, true)
output.writeBytes(obj.bytes)
}
override fun read(kryo: Kryo, input: Input, type: Class<SerializedBytes<Any>>): SerializedBytes<Any> {
return SerializedBytes(input.readBytes(input.readVarInt(true)))
}
}
/**
* Serializes properties and deserializes by using the constructor. This assumes that all backed properties are
* set via the constructor and the class is immutable.
*/
class ImmutableClassSerializer<T : Any>(val klass: KClass<T>) : Serializer<T>() {
val props by lazy { klass.memberProperties.sortedBy { it.name } }
val propsByName by lazy { props.associateBy { it.name } }
val constructor by lazy { klass.primaryConstructor!! }
init {
// Verify that this class is immutable (all properties are final).
// We disable this check inside SGX as the reflection blows up.
if (!SgxSupport.isInsideEnclave) {
assert(props.none { it is KMutableProperty<*> })
}
}
// Just a utility to help us catch cases where nodes are running out of sync versions.
private fun hashParameters(params: List<KParameter>): Int {
return params.map {
(it.name ?: "") + it.index.toString() + it.type.javaType.typeName
}.hashCode()
}
override fun write(kryo: Kryo, output: Output, obj: T) {
output.writeVarInt(constructor.parameters.size, true)
output.writeInt(hashParameters(constructor.parameters))
for (param in constructor.parameters) {
val kProperty = propsByName[param.name!!]!!
kProperty.isAccessible = true
when (param.type.javaType.typeName) {
"int" -> output.writeVarInt(kProperty.get(obj) as Int, true)
"long" -> output.writeVarLong(kProperty.get(obj) as Long, true)
"short" -> output.writeShort(kProperty.get(obj) as Int)
"char" -> output.writeChar(kProperty.get(obj) as Char)
"byte" -> output.writeByte(kProperty.get(obj) as Byte)
"double" -> output.writeDouble(kProperty.get(obj) as Double)
"float" -> output.writeFloat(kProperty.get(obj) as Float)
"boolean" -> output.writeBoolean(kProperty.get(obj) as Boolean)
else -> try {
kryo.writeClassAndObject(output, kProperty.get(obj))
} catch (e: Exception) {
throw IllegalStateException("Failed to serialize ${param.name} in ${klass.qualifiedName}", e)
}
}
}
}
override fun read(kryo: Kryo, input: Input, type: Class<T>): T {
assert(type.kotlin == klass)
val numFields = input.readVarInt(true)
val fieldTypeHash = input.readInt()
// A few quick checks for data evolution. Note that this is not guaranteed to catch every problem! But it's
// good enough for a prototype.
if (numFields != constructor.parameters.size)
throw KryoException("Mismatch between number of constructor parameters and number of serialised fields " +
"for ${klass.qualifiedName} ($numFields vs ${constructor.parameters.size})")
if (fieldTypeHash != hashParameters(constructor.parameters))
throw KryoException("Hashcode mismatch for parameter types for ${klass.qualifiedName}: unsupported type evolution has happened.")
val args = arrayOfNulls<Any?>(numFields)
var cursor = 0
for (param in constructor.parameters) {
args[cursor++] = when (param.type.javaType.typeName) {
"int" -> input.readVarInt(true)
"long" -> input.readVarLong(true)
"short" -> input.readShort()
"char" -> input.readChar()
"byte" -> input.readByte()
"double" -> input.readDouble()
"float" -> input.readFloat()
"boolean" -> input.readBoolean()
else -> kryo.readClassAndObject(input)
}
}
// If the constructor throws an exception, pass it through instead of wrapping it.
return try {
constructor.call(*args)
} catch (e: InvocationTargetException) {
throw e.cause!!
}
}
}
// TODO This is a temporary inefficient serializer for sending InputStreams through RPC. This may be done much more
// efficiently using Artemis's large message feature.
object InputStreamSerializer : Serializer<InputStream>() {
override fun write(kryo: Kryo, output: Output, stream: InputStream) {
val buffer = ByteArray(4096)
while (true) {
val numberOfBytesRead = stream.read(buffer)
if (numberOfBytesRead != -1) {
output.writeInt(numberOfBytesRead, true)
output.writeBytes(buffer, 0, numberOfBytesRead)
} else {
output.writeInt(0)
break
}
}
}
override fun read(kryo: Kryo, input: Input, type: Class<InputStream>): InputStream {
val chunks = ArrayList<ByteArray>()
while (true) {
val chunk = input.readBytesWithLength()
if (chunk.isEmpty()) {
break
} else {
chunks.add(chunk)
}
}
val flattened = ByteArray(chunks.sumBy { it.size })
var offset = 0
for (chunk in chunks) {
System.arraycopy(chunk, 0, flattened, offset, chunk.size)
offset += chunk.size
}
return ByteArrayInputStream(flattened)
}
}
inline fun <T> Kryo.useClassLoader(cl: ClassLoader, body: () -> T): T {
val tmp = this.classLoader ?: ClassLoader.getSystemClassLoader()
this.classLoader = cl
try {
return body()
} finally {
this.classLoader = tmp
}
}
fun Output.writeBytesWithLength(byteArray: ByteArray) {
this.writeInt(byteArray.size, true)
this.writeBytes(byteArray)
}
fun Input.readBytesWithLength(): ByteArray {
val size = this.readInt(true)
return this.readBytes(size)
}
/** A serialisation engine that knows how to deserialise code inside a sandbox */
@ThreadSafe
object WireTransactionSerializer : Serializer<WireTransaction>() {
@VisibleForTesting
internal val attachmentsClassLoaderEnabled = "attachments.class.loader.enabled"
override fun write(kryo: Kryo, output: Output, obj: WireTransaction) {
kryo.writeClassAndObject(output, obj.inputs)
kryo.writeClassAndObject(output, obj.attachments)
kryo.writeClassAndObject(output, obj.outputs)
kryo.writeClassAndObject(output, obj.commands)
kryo.writeClassAndObject(output, obj.notary)
kryo.writeClassAndObject(output, obj.timeWindow)
kryo.writeClassAndObject(output, obj.privacySalt)
}
private fun attachmentsClassLoader(kryo: Kryo, attachmentHashes: List<SecureHash>): ClassLoader? {
kryo.context[attachmentsClassLoaderEnabled] as? Boolean ?: false || return null
val serializationContext = kryo.serializationContext() ?: return null // Some tests don't set one.
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
attachmentHashes.forEach { id ->
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
return AttachmentsClassLoader(attachments)
}
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<WireTransaction>): WireTransaction {
val inputs = kryo.readClassAndObject(input) as List<StateRef>
val attachmentHashes = kryo.readClassAndObject(input) as List<SecureHash>
// If we're deserialising in the sandbox context, we use our special attachments classloader.
// Otherwise we just assume the code we need is on the classpath already.
kryo.useClassLoader(attachmentsClassLoader(kryo, attachmentHashes) ?: javaClass.classLoader) {
val outputs = kryo.readClassAndObject(input) as List<TransactionState<ContractState>>
val commands = kryo.readClassAndObject(input) as List<Command<*>>
val notary = kryo.readClassAndObject(input) as Party?
val timeWindow = kryo.readClassAndObject(input) as TimeWindow?
val privacySalt = kryo.readClassAndObject(input) as PrivacySalt
return WireTransaction(inputs, attachmentHashes, outputs, commands, notary, timeWindow, privacySalt)
}
}
}
@ThreadSafe
object NotaryChangeWireTransactionSerializer : Serializer<NotaryChangeWireTransaction>() {
override fun write(kryo: Kryo, output: Output, obj: NotaryChangeWireTransaction) {
kryo.writeClassAndObject(output, obj.inputs)
kryo.writeClassAndObject(output, obj.notary)
kryo.writeClassAndObject(output, obj.newNotary)
}
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<NotaryChangeWireTransaction>): NotaryChangeWireTransaction {
val inputs = kryo.readClassAndObject(input) as List<StateRef>
val notary = kryo.readClassAndObject(input) as Party
val newNotary = kryo.readClassAndObject(input) as Party
return NotaryChangeWireTransaction(inputs, notary, newNotary)
}
}
@ThreadSafe
object SignedTransactionSerializer : Serializer<SignedTransaction>() {
override fun write(kryo: Kryo, output: Output, obj: SignedTransaction) {
kryo.writeClassAndObject(output, obj.txBits)
kryo.writeClassAndObject(output, obj.sigs)
}
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<SignedTransaction>): SignedTransaction {
return SignedTransaction(
kryo.readClassAndObject(input) as SerializedBytes<CoreTransaction>,
kryo.readClassAndObject(input) as List<TransactionSignature>
)
}
}
/** For serialising an ed25519 private key */
@ThreadSafe
object Ed25519PrivateKeySerializer : Serializer<EdDSAPrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: EdDSAPrivateKey) {
check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec)
output.writeBytesWithLength(obj.seed)
}
override fun read(kryo: Kryo, input: Input, type: Class<EdDSAPrivateKey>): EdDSAPrivateKey {
val seed = input.readBytesWithLength()
return EdDSAPrivateKey(EdDSAPrivateKeySpec(seed, Crypto.EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec))
}
}
/** For serialising an ed25519 public key */
@ThreadSafe
object Ed25519PublicKeySerializer : Serializer<EdDSAPublicKey>() {
override fun write(kryo: Kryo, output: Output, obj: EdDSAPublicKey) {
check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec)
output.writeBytesWithLength(obj.abyte)
}
override fun read(kryo: Kryo, input: Input, type: Class<EdDSAPublicKey>): EdDSAPublicKey {
val A = input.readBytesWithLength()
return EdDSAPublicKey(EdDSAPublicKeySpec(A, Crypto.EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec))
}
}
/** For serialising an ed25519 public key */
@ThreadSafe
object ECPublicKeyImplSerializer : Serializer<ECPublicKeyImpl>() {
override fun write(kryo: Kryo, output: Output, obj: ECPublicKeyImpl) {
output.writeBytesWithLength(obj.encoded)
}
override fun read(kryo: Kryo, input: Input, type: Class<ECPublicKeyImpl>): ECPublicKeyImpl {
val A = input.readBytesWithLength()
val der = DerValue(A)
return ECPublicKeyImpl.parse(der) as ECPublicKeyImpl
}
}
// TODO Implement standardized serialization of CompositeKeys. See JIRA issue: CORDA-249.
@ThreadSafe
object CompositeKeySerializer : Serializer<CompositeKey>() {
override fun write(kryo: Kryo, output: Output, obj: CompositeKey) {
output.writeInt(obj.threshold)
output.writeInt(obj.children.size)
obj.children.forEach { kryo.writeClassAndObject(output, it) }
}
override fun read(kryo: Kryo, input: Input, type: Class<CompositeKey>): CompositeKey {
val threshold = input.readInt()
val children = readListOfLength<CompositeKey.NodeAndWeight>(kryo, input, minLen = 2)
val builder = CompositeKey.Builder()
children.forEach { builder.addKey(it.node, it.weight) }
return builder.build(threshold) as CompositeKey
}
}
@ThreadSafe
object PrivateKeySerializer : Serializer<PrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: PrivateKey) {
output.writeBytesWithLength(obj.encoded)
}
override fun read(kryo: Kryo, input: Input, type: Class<PrivateKey>): PrivateKey {
val A = input.readBytesWithLength()
return Crypto.decodePrivateKey(A)
}
}
/** For serialising a public key */
@ThreadSafe
object PublicKeySerializer : Serializer<PublicKey>() {
override fun write(kryo: Kryo, output: Output, obj: PublicKey) {
// TODO: Instead of encoding to the default X509 format, we could have a custom per key type (space-efficient) serialiser.
output.writeBytesWithLength(obj.encoded)
}
override fun read(kryo: Kryo, input: Input, type: Class<PublicKey>): PublicKey {
val A = input.readBytesWithLength()
return Crypto.decodePublicKey(A)
}
}
/**
* Helper function for reading lists with number of elements at the beginning.
* @param minLen minimum number of elements we expect for list to include, defaults to 1
* @param expectedLen expected length of the list, defaults to null if arbitrary length list read
*/
inline fun <reified T> readListOfLength(kryo: Kryo, input: Input, minLen: Int = 1, expectedLen: Int? = null): List<T> {
val elemCount = input.readInt()
if (elemCount < minLen) throw KryoException("Cannot deserialize list, too little elements. Minimum required: $minLen, got: $elemCount")
if (expectedLen != null && elemCount != expectedLen)
throw KryoException("Cannot deserialize list, expected length: $expectedLen, got: $elemCount.")
val list = (1..elemCount).map { kryo.readClassAndObject(input) as T }
return list
}
/**
* We need to disable whitelist checking during calls from our Kryo code to register a serializer, since it checks
* for existing registrations and then will enter our [CordaClassResolver.getRegistration] method.
*/
open class CordaKryo(classResolver: ClassResolver) : Kryo(classResolver, MapReferenceResolver()) {
override fun register(type: Class<*>?): Registration {
(classResolver as? CordaClassResolver)?.disableWhitelist()
try {
return super.register(type)
} finally {
(classResolver as? CordaClassResolver)?.enableWhitelist()
}
}
override fun register(type: Class<*>?, id: Int): Registration {
(classResolver as? CordaClassResolver)?.disableWhitelist()
try {
return super.register(type, id)
} finally {
(classResolver as? CordaClassResolver)?.enableWhitelist()
}
}
override fun register(type: Class<*>?, serializer: Serializer<*>?): Registration {
(classResolver as? CordaClassResolver)?.disableWhitelist()
try {
return super.register(type, serializer)
} finally {
(classResolver as? CordaClassResolver)?.enableWhitelist()
}
}
override fun register(registration: Registration?): Registration {
(classResolver as? CordaClassResolver)?.disableWhitelist()
try {
return super.register(registration)
} finally {
(classResolver as? CordaClassResolver)?.enableWhitelist()
}
}
}
inline fun <T : Any> Kryo.register(
type: KClass<T>,
crossinline read: (Kryo, Input) -> T,
crossinline write: (Kryo, Output, T) -> Unit): Registration {
return register(
type.java,
object : Serializer<T>() {
override fun read(kryo: Kryo, input: Input, clazz: Class<T>): T = read(kryo, input)
override fun write(kryo: Kryo, output: Output, obj: T) = write(kryo, output, obj)
}
)
}
/**
* Use this method to mark any types which can have the same instance within it more than once. This will make sure
* the serialised form is stable across multiple serialise-deserialise cycles. Using this on a type with internal cyclic
* references will throw a stack overflow exception during serialisation.
*/
inline fun <reified T : Any> Kryo.noReferencesWithin() {
register(T::class.java, NoReferencesSerializer(getSerializer(T::class.java)))
}
class NoReferencesSerializer<T>(val baseSerializer: Serializer<T>) : Serializer<T>() {
override fun read(kryo: Kryo, input: Input, type: Class<T>): T {
return kryo.withoutReferences { baseSerializer.read(kryo, input, type) }
}
override fun write(kryo: Kryo, output: Output, obj: T) {
kryo.withoutReferences { baseSerializer.write(kryo, output, obj) }
}
}
fun <T> Kryo.withoutReferences(block: () -> T): T {
val previousValue = setReferences(false)
try {
return block()
} finally {
references = previousValue
}
}
/** For serialising a Logger. */
@ThreadSafe
object LoggerSerializer : Serializer<Logger>() {
override fun write(kryo: Kryo, output: Output, obj: Logger) {
output.writeString(obj.name)
}
override fun read(kryo: Kryo, input: Input, type: Class<Logger>): Logger {
return LoggerFactory.getLogger(input.readString())
}
}
object ClassSerializer : Serializer<Class<*>>() {
override fun read(kryo: Kryo, input: Input, type: Class<Class<*>>): Class<*> {
val className = input.readString()
return Class.forName(className)
}
override fun write(kryo: Kryo, output: Output, clazz: Class<*>) {
output.writeString(clazz.name)
}
}
/**
* For serialising an [X500Name] without touching Sun internal classes.
*/
@ThreadSafe
object X500NameSerializer : Serializer<X500Name>() {
override fun read(kryo: Kryo, input: Input, type: Class<X500Name>): X500Name {
return X500Name.getInstance(ASN1InputStream(input.readBytes()).readObject())
}
override fun write(kryo: Kryo, output: Output, obj: X500Name) {
output.writeBytes(obj.encoded)
}
}
/**
* For serialising an [CertPath] in an X.500 standard format.
*/
@ThreadSafe
object CertPathSerializer : Serializer<CertPath>() {
val factory: CertificateFactory = CertificateFactory.getInstance("X.509")
override fun read(kryo: Kryo, input: Input, type: Class<CertPath>): CertPath {
return factory.generateCertPath(input)
}
override fun write(kryo: Kryo, output: Output, obj: CertPath) {
output.writeBytes(obj.encoded)
}
}
/**
* For serialising an [X509CertificateHolder] in an X.500 standard format.
*/
@ThreadSafe
object X509CertificateSerializer : Serializer<X509CertificateHolder>() {
override fun read(kryo: Kryo, input: Input, type: Class<X509CertificateHolder>): X509CertificateHolder {
return X509CertificateHolder(input.readBytes())
}
override fun write(kryo: Kryo, output: Output, obj: X509CertificateHolder) {
output.writeBytes(obj.encoded)
}
}
fun Kryo.serializationContext(): SerializeAsTokenContext? = context.get(serializationContextKey) as? SerializeAsTokenContext

View File

@ -0,0 +1,37 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.sequence
import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
/**
* This [Kryo] custom [Serializer] switches the object graph of anything annotated with `@CordaSerializable`
* to using the AMQP serialization wire format, and simply writes that out as bytes to the wire.
*
* There is no need to write out the length, since this can be peeked out of the first few bytes of the stream.
*/
class KryoAMQPSerializer(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : Serializer<Any>() {
override fun write(kryo: Kryo, output: Output, obj: Any) {
val bytes = serializationFactory.serialize(obj, serializationContext.withPreferredSerializationVersion(AmqpHeaderV1_0)).bytes
// No need to write out the size since it's encoded within the AMQP.
output.write(bytes)
}
override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any {
// Use our helper functions to peek the size of the serialized object out of the AMQP byte stream.
val peekedBytes = input.readBytes(DeserializationInput.BYTES_NEEDED_TO_PEEK)
val size = DeserializationInput.peekSize(peekedBytes)
val allBytes = peekedBytes.copyOf(size)
input.readBytes(allBytes, peekedBytes.size, size - peekedBytes.size)
return serializationFactory.deserialize(allBytes.sequence(), type, serializationContext)
}
}

View File

@ -0,0 +1,10 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import net.corda.core.serialization.SerializationCustomization
class KryoSerializationCustomization(val kryo: Kryo) : SerializationCustomization {
override fun addToWhitelist(type: Class<*>) {
kryo.addToWhitelist(type)
}
}

View File

@ -0,0 +1,267 @@
package net.corda.nodeapi.internal.serialization
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import io.requery.util.CloseableIterator
import net.corda.core.internal.LazyPool
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
import java.util.*
import java.util.concurrent.ConcurrentHashMap
object NotSupportedSeralizationScheme : SerializationScheme {
private fun doThrow(): Nothing = throw UnsupportedOperationException("Serialization scheme not supported.")
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean = doThrow()
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T = doThrow()
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> = doThrow()
}
data class SerializationContextImpl(override val preferedSerializationVersion: ByteSequence,
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase) : SerializationContext {
override fun withProperty(property: Any, value: Any): SerializationContext {
return copy(properties = properties + (property to value))
}
override fun withoutReferences(): SerializationContext {
return copy(objectReferencesEnabled = false)
}
override fun withClassLoader(classLoader: ClassLoader): SerializationContext {
return copy(deserializationClassLoader = classLoader)
}
override fun withWhitelisted(clazz: Class<*>): SerializationContext {
return copy(whitelist = object : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name
})
}
override fun withPreferredSerializationVersion(versionHeader: ByteSequence) = copy(preferedSerializationVersion = versionHeader)
}
private const val HEADER_SIZE: Int = 8
open class SerializationFactoryImpl : SerializationFactory {
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
private val registeredSchemes: MutableCollection<SerializationScheme> = Collections.synchronizedCollection(mutableListOf())
// TODO: This is read-mostly. Probably a faster implementation to be found.
private val schemes: ConcurrentHashMap<Pair<ByteSequence, SerializationContext.UseCase>, SerializationScheme> = ConcurrentHashMap()
private fun schemeFor(byteSequence: ByteSequence, target: SerializationContext.UseCase): SerializationScheme {
// truncate sequence to 8 bytes, and make sure it's a copy to avoid holding onto large ByteArrays
return schemes.computeIfAbsent(byteSequence.take(HEADER_SIZE).copy() to target) {
for (scheme in registeredSchemes) {
if (scheme.canDeserializeVersion(it.first, it.second)) {
return@computeIfAbsent scheme
}
}
NotSupportedSeralizationScheme
}
}
@Throws(NotSerializableException::class)
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T = schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context)
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
return schemeFor(context.preferedSerializationVersion, context.useCase).serialize(obj, context)
}
fun registerScheme(scheme: SerializationScheme) {
check(schemes.isEmpty()) { "All serialization schemes must be registered before any scheme is used." }
registeredSchemes += scheme
}
val alreadyRegisteredSchemes: Collection<SerializationScheme> get() = Collections.unmodifiableCollection(registeredSchemes)
override fun toString(): String {
return "${this.javaClass.name} registeredSchemes=$registeredSchemes ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is SerializationFactoryImpl &&
other.registeredSchemes == this.registeredSchemes
}
override fun hashCode(): Int = registeredSchemes.hashCode()
}
private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>() {
override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) {
val message = if (closeable is CloseableIterator<*>) {
"A live Iterator pointing to the database has been detected during flow checkpointing. This may be due " +
"to a Vault query - move it into a private method."
} else {
"${closeable.javaClass.name}, which is a closeable resource, has been detected during flow checkpointing. " +
"Restoring such resources across node restarts is not supported. Make sure code accessing it is " +
"confined to a private method or the reference is nulled out."
}
throw UnsupportedOperationException(message)
}
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
}
abstract class AbstractKryoSerializationScheme(val serializationFactory: SerializationFactory) : SerializationScheme {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool
protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool
private fun getPool(context: SerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(serializationFactory, context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
classLoader = it.second
}
}.build()
SerializationContext.UseCase.RPCClient ->
rpcClientKryoPool(context)
SerializationContext.UseCase.RPCServer ->
rpcServerKryoPool(context)
else ->
KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(serializationFactory, context))).apply { classLoader = it.second }
}.build()
}
}
}
private fun <T : Any> withContext(kryo: Kryo, context: SerializationContext, block: (Kryo) -> T): T {
kryo.context.ensureCapacity(context.properties.size)
context.properties.forEach { kryo.context.put(it.key, it.value) }
try {
return block(kryo)
} finally {
kryo.context.clear()
}
}
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
val pool = getPool(context)
val headerSize = KryoHeaderV0_1.size
val header = byteSequence.take(headerSize)
if (header != KryoHeaderV0_1) {
throw KryoException("Serialized bytes header does not match expected format.")
}
Input(byteSequence.bytes, byteSequence.offset + headerSize, byteSequence.size - headerSize).use { input ->
return pool.run { kryo ->
withContext(kryo, context) {
@Suppress("UNCHECKED_CAST")
if (context.objectReferencesEnabled) {
kryo.readClassAndObject(input) as T
} else {
kryo.withoutReferences { kryo.readClassAndObject(input) as T }
}
}
}
}
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
val pool = getPool(context)
return pool.run { kryo ->
withContext(kryo, context) {
serializeOutputStreamPool.run { stream ->
serializeBufferPool.run { buffer ->
Output(buffer).use {
it.outputStream = stream
it.writeBytes(KryoHeaderV0_1.bytes)
if (context.objectReferencesEnabled) {
kryo.writeClassAndObject(it, obj)
} else {
kryo.withoutReferences { kryo.writeClassAndObject(it, obj) }
}
}
SerializedBytes(stream.toByteArray())
}
}
}
}
}
}
private val serializeBufferPool = LazyPool(
newInstance = { ByteArray(64 * 1024) }
)
private val serializeOutputStreamPool = LazyPool(
clear = ByteArrayOutputStream::reset,
shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large
newInstance = { ByteArrayOutputStream(64 * 1024) }
)
// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB
val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray(Charsets.UTF_8))
val KRYO_P2P_CONTEXT = SerializationContextImpl(KryoHeaderV0_1,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.P2P)
val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(KryoHeaderV0_1,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.RPCServer)
val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(KryoHeaderV0_1,
SerializationDefaults.javaClass.classLoader,
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
emptyMap(),
true,
SerializationContext.UseCase.RPCClient)
val KRYO_STORAGE_CONTEXT = SerializationContextImpl(KryoHeaderV0_1,
SerializationDefaults.javaClass.classLoader,
AllButBlacklisted,
emptyMap(),
true,
SerializationContext.UseCase.Storage)
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(KryoHeaderV0_1,
SerializationDefaults.javaClass.classLoader,
QuasarWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint)
object QuasarWhitelist : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = true
}
interface SerializationScheme {
// byteSequence expected to just be the 8 bytes necessary for versioning
fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T>
}

View File

@ -0,0 +1,56 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SerializeAsTokenContext
val serializationContextKey = SerializeAsTokenContext::class.java
fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext)
/**
* A context for mapping SerializationTokens to/from SerializeAsTokens.
*
* A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens.
* In our case this can be the [ServiceHub].
*
* Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary
* when serializing to enable/disable tokenization.
*/
class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext {
constructor(toBeTokenized: Any, serializationFactory: SerializationFactory, context: SerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializationFactory.serialize(toBeTokenized, context.withTokenContext(this))
})
private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>()
private var readOnly = false
init {
/**
* Go ahead and eagerly serialize the object to register all of the tokens in the context.
*
* This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which
* are encountered in the object graph as they are serialized and will therefore register the token to
* object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or
* accidental registrations from occuring as these could not be deserialized in a deserialization-first
* scenario if they are not part of this iniital context construction serialization.
*/
init(this)
readOnly = true
}
override fun putSingleton(toBeTokenized: SerializeAsToken) {
val className = toBeTokenized.javaClass.name
if (className !in classNameToSingleton) {
// Only allowable if we are in SerializeAsTokenContext init (readOnly == false)
if (readOnly) {
throw UnsupportedOperationException("Attempt to write token for lazy registered ${className}. All tokens should be registered during context construction.")
}
classNameToSingleton[className] = toBeTokenized
}
}
override fun getSingleton(className: String) = classNameToSingleton[className] ?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
}

View File

@ -0,0 +1,25 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.SerializationToken
import net.corda.core.serialization.SerializeAsToken
/**
* A Kryo serializer for [SerializeAsToken] implementations.
*/
class SerializeAsTokenSerializer<T : SerializeAsToken> : Serializer<T>() {
override fun write(kryo: Kryo, output: Output, obj: T) {
kryo.writeClassAndObject(output, obj.toToken(kryo.serializationContext() ?: throw KryoException("Attempt to write a ${SerializeAsToken::class.simpleName} instance of ${obj.javaClass.name} without initialising a context")))
}
override fun read(kryo: Kryo, input: Input, type: Class<T>): T {
val token = (kryo.readClassAndObject(input) as? SerializationToken) ?: throw KryoException("Non-token read for tokenized type: ${type.name}")
val fromToken = token.fromToken(kryo.serializationContext() ?: throw KryoException("Attempt to read a token for a ${SerializeAsToken::class.simpleName} instance of ${type.name} without initialising a context"))
return type.castIfPossible(fromToken) ?: throw KryoException("Token read ($token) did not return expected tokenized type: ${type.name}")
}
}

View File

@ -0,0 +1,29 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Type
/**
* Serializer / deserializer for native AMQP types (Int, Float, String etc).
*
* [ByteArray] is automatically marshalled to/from the Proton-J wrapper, [Binary].
*/
class AMQPPrimitiveSerializer(clazz: Class<*>) : AMQPSerializer<Any> {
override val typeDescriptor: String = SerializerFactory.primitiveTypeName(clazz)!!
override val type: Type = clazz
// NOOP since this is a primitive type.
override fun writeClassInfo(output: SerializationOutput) {
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
if (obj is ByteArray) {
data.putObject(Binary(obj))
} else {
data.putObject(obj)
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any = (obj as? Binary)?.array ?: obj
}

View File

@ -0,0 +1,38 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Type
/**
* Implemented to serialize and deserialize different types of objects to/from AMQP.
*/
interface AMQPSerializer<out T> {
/**
* The JVM type this can serialize and deserialize.
*/
val type: Type
/**
* Textual unique representation of the JVM type this represents. Will be encoded into the AMQP stream and
* will appear in the schema.
*
* This should be unique enough that we can use one global cache of [AMQPSerializer]s and use this as the look up key.
*/
val typeDescriptor: String
/**
* Add anything required to the AMQP schema via [SerializationOutput.writeTypeNotations] and any dependent serializers
* via [SerializationOutput.requireSerializer]. e.g. for the elements of an array.
*/
fun writeClassInfo(output: SerializationOutput)
/**
* Write the given object, with declared type, to the output.
*/
fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput)
/**
* Read the given object from the input. The envelope is provided in case the schema is required.
*/
fun readObject(obj: Any, schema: Schema, input: DeserializationInput): T
}

View File

@ -0,0 +1,166 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.Type
/**
* Serialization / deserialization of arrays.
*/
open class ArraySerializer(override val type: Type, factory: SerializerFactory) : AMQPSerializer<Any> {
companion object {
fun make(type: Type, factory: SerializerFactory) = when (type) {
Array<Char>::class.java -> CharArraySerializer(factory)
else -> ArraySerializer(type, factory)
}
}
// because this might be an array of array of primitives (to any recursive depth) and
// because we care that the lowest type is unboxed we can't rely on the inbuilt type
// id to generate it properly (it will always return [[[Ljava.lang.type -> type[][][]
// for example).
//
// We *need* to retain knowledge for AMQP deserialisation weather that lowest primitive
// was boxed or unboxed so just infer it recursively
private fun calcTypeName(type: Type) : String =
if (type.componentType().isArray()) {
val typeName = calcTypeName(type.componentType()); "$typeName[]"
}
else {
val arrayType = if (type.asClass()!!.componentType.isPrimitive) "[p]" else "[]"
"${type.componentType().typeName}$arrayType"
}
override val typeDescriptor by lazy { "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" }
internal val elementType: Type by lazy { type.componentType() }
internal open val typeName by lazy { calcTypeName(type) }
internal val typeNotation: TypeNotation by lazy {
RestrictedType(typeName, null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList())
}
override fun writeClassInfo(output: SerializationOutput) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(elementType)
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
// Write described
data.withDescribed(typeNotation.descriptor) {
withList {
for (entry in obj as Array<*>) {
output.writeObjectOrNull(entry, this, elementType)
}
}
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
if (obj is List<*>) {
return obj.map { input.readObjectOrNull(it, schema, elementType) }.toArrayOfType(elementType)
} else throw NotSerializableException("Expected a List but found $obj")
}
open fun <T> List<T>.toArrayOfType(type: Type): Any {
val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type")
val list = this
return java.lang.reflect.Array.newInstance(elementType, this.size).apply {
(0..lastIndex).forEach { java.lang.reflect.Array.set(this, it, list[it]) }
}
}
}
// Boxed Character arrays required a specialisation to handle the type conversion properly when populating
// the array since Kotlin won't allow an implicit cast from Int (as they're stored as 16bit ints) to Char
class CharArraySerializer(factory: SerializerFactory) : ArraySerializer(Array<Char>::class.java, factory) {
override fun <T> List<T>.toArrayOfType(type: Type): Any {
val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type")
val list = this
return java.lang.reflect.Array.newInstance(elementType, this.size).apply {
(0..lastIndex).forEach { java.lang.reflect.Array.set(this, it, (list[it] as Int).toChar()) }
}
}
}
// Specialisation of [ArraySerializer] that handles arrays of unboxed java primitive types
abstract class PrimArraySerializer(type: Type, factory: SerializerFactory) : ArraySerializer(type, factory) {
companion object {
// We don't need to handle the unboxed byte type as that is coercible to a byte array, but
// the other 7 primitive types we do
val primTypes: Map<Type, (SerializerFactory) -> PrimArraySerializer> = mapOf(
IntArray::class.java to { f -> PrimIntArraySerializer(f) },
CharArray::class.java to { f -> PrimCharArraySerializer(f) },
BooleanArray::class.java to { f -> PrimBooleanArraySerializer(f) },
FloatArray::class.java to { f -> PrimFloatArraySerializer(f) },
ShortArray::class.java to { f -> PrimShortArraySerializer(f) },
DoubleArray::class.java to { f -> PrimDoubleArraySerializer(f) },
LongArray::class.java to { f -> PrimLongArraySerializer(f) }
// ByteArray::class.java <-> NOT NEEDED HERE (see comment above)
)
fun make(type: Type, factory: SerializerFactory) = primTypes[type]!!(factory)
}
fun localWriteObject(data: Data, func: () -> Unit) {
data.withDescribed(typeNotation.descriptor) { withList { func() } }
}
}
class PrimIntArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(IntArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as IntArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
}
class PrimCharArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(CharArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as CharArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
override fun <T> List<T>.toArrayOfType(type: Type): Any {
val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type")
val list = this
return java.lang.reflect.Array.newInstance(elementType, this.size).apply {
val array = this
(0..lastIndex).forEach { java.lang.reflect.Array.set(array, it, (list[it] as Int).toChar()) }
}
}
}
class PrimBooleanArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(BooleanArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as BooleanArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
}
class PrimDoubleArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(DoubleArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as DoubleArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
}
class PrimFloatArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(FloatArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as FloatArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
}
class PrimShortArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(ShortArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as ShortArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
}
class PrimLongArraySerializer(factory: SerializerFactory) :
PrimArraySerializer(LongArray::class.java, factory) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
localWriteObject(data) { (obj as LongArray).forEach { output.writeObjectOrNull(it, data, elementType) } }
}
}

View File

@ -0,0 +1,58 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.util.*
import kotlin.collections.Collection
import kotlin.collections.LinkedHashSet
import kotlin.collections.Set
/**
* Serialization / deserialization of predefined set of supported [Collection] types covering mostly [List]s and [Set]s.
*/
class CollectionSerializer(val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(declaredType.toString())
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
companion object {
private val supportedTypes: Map<Class<out Collection<*>>, (List<*>) -> Collection<*>> = mapOf(
Collection::class.java to { list -> Collections.unmodifiableCollection(list) },
List::class.java to { list -> Collections.unmodifiableList(list) },
Set::class.java to { list -> Collections.unmodifiableSet(LinkedHashSet(list)) },
SortedSet::class.java to { list -> Collections.unmodifiableSortedSet(TreeSet(list)) },
NavigableSet::class.java to { list -> Collections.unmodifiableNavigableSet(TreeSet(list)) }
)
private fun findConcreteType(clazz: Class<*>): (List<*>) -> Collection<*> {
return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported collection type $clazz.")
}
}
private val concreteBuilder: (List<*>) -> Collection<*> = findConcreteType(declaredType.rawType as Class<*>)
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList())
override fun writeClassInfo(output: SerializationOutput) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(declaredType.actualTypeArguments[0])
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
// Write described
data.withDescribed(typeNotation.descriptor) {
withList {
for (entry in obj as Collection<*>) {
output.writeObjectOrNull(entry, this, declaredType.actualTypeArguments[0])
}
}
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
// TODO: Can we verify the entries in the list?
return concreteBuilder((obj as List<*>).map { input.readObjectOrNull(it, schema, declaredType.actualTypeArguments[0]) })
}
}

View File

@ -0,0 +1,175 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.nameForType
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Type
/**
* Base class for serializers of core platform types that do not conform to the usual serialization rules and thus
* cannot be automatically serialized.
*/
abstract class CustomSerializer<T> : AMQPSerializer<T> {
/**
* This is a collection of custom serializers that this custom serializer depends on. e.g. for proxy objects
* that refer to other custom types etc.
*/
abstract val additionalSerializers: Iterable<CustomSerializer<out Any>>
/**
* This method should return true if the custom serializer can serialize an instance of the class passed as the
* parameter.
*/
abstract fun isSerializerFor(clazz: Class<*>): Boolean
protected abstract val descriptor: Descriptor
/**
* This exists purely for documentation and cross-platform purposes. It is not used by our serialization / deserialization
* code path.
*/
abstract val schemaForDocumentation: Schema
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
data.withDescribed(descriptor) {
@Suppress("UNCHECKED_CAST")
writeDescribedObject(obj as T, data, type, output)
}
}
abstract fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput)
/**
* This custom serializer represents a sort of symbolic link from a subclass to a super class, where the super
* class custom serializer is responsible for the "on the wire" format but we want to create a reference to the
* subclass in the schema, so that we can distinguish between subclasses.
*/
// TODO: should this be a custom serializer at all, or should it just be a plain AMQPSerializer?
class SubClass<T>(protected val clazz: Class<*>, protected val superClassSerializer: CustomSerializer<T>) : CustomSerializer<T>() {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = emptyList()
// TODO: should this be empty or contain the schema of the super?
override val schemaForDocumentation = Schema(emptyList())
override fun isSerializerFor(clazz: Class<*>): Boolean = clazz == this.clazz
override val type: Type get() = clazz
override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${fingerprintForDescriptors(superClassSerializer.typeDescriptor, nameForType(clazz))}"
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(clazz), null, emptyList(), SerializerFactory.nameForType(superClassSerializer.type), Descriptor(typeDescriptor, null), emptyList())
override fun writeClassInfo(output: SerializationOutput) {
output.writeTypeNotations(typeNotation)
}
override val descriptor: Descriptor = Descriptor(typeDescriptor)
override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput) {
superClassSerializer.writeDescribedObject(obj, data, type, output)
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): T {
return superClassSerializer.readObject(obj, schema, input)
}
}
/**
* Additional base features for a custom serializer for a particular class, that excludes subclasses.
*/
abstract class Is<T>(protected val clazz: Class<T>) : CustomSerializer<T>() {
override fun isSerializerFor(clazz: Class<*>): Boolean = clazz == this.clazz
override val type: Type get() = clazz
override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${nameForType(clazz)}"
override fun writeClassInfo(output: SerializationOutput) {}
override val descriptor: Descriptor = Descriptor(typeDescriptor)
}
/**
* Additional base features for a custom serializer for all implementations of a particular interface or super class.
*/
abstract class Implements<T>(protected val clazz: Class<T>) : CustomSerializer<T>() {
override fun isSerializerFor(clazz: Class<*>): Boolean = this.clazz.isAssignableFrom(clazz)
override val type: Type get() = clazz
override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${nameForType(clazz)}"
override fun writeClassInfo(output: SerializationOutput) {}
override val descriptor: Descriptor = Descriptor(typeDescriptor)
}
/**
* Additional base features over and above [Implements] or [Is] custom serializer for when the serialized form should be
* the serialized form of a proxy class, and the object can be re-created from that proxy on deserialization.
*
* The proxy class must use only types which are either native AMQP or other types for which there are pre-registered
* custom serializers.
*/
abstract class Proxy<T, P>(protected val clazz: Class<T>,
protected val proxyClass: Class<P>,
protected val factory: SerializerFactory,
val withInheritance: Boolean = true) : CustomSerializer<T>() {
override fun isSerializerFor(clazz: Class<*>): Boolean = if (withInheritance) this.clazz.isAssignableFrom(clazz) else this.clazz == clazz
override val type: Type get() = clazz
override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${nameForType(clazz)}"
override fun writeClassInfo(output: SerializationOutput) {}
override val descriptor: Descriptor = Descriptor(typeDescriptor)
private val proxySerializer: ObjectSerializer by lazy { ObjectSerializer(proxyClass, factory) }
override val schemaForDocumentation: Schema by lazy {
val typeNotations = mutableSetOf<TypeNotation>(CompositeType(nameForType(type), null, emptyList(), descriptor, (proxySerializer.typeNotation as CompositeType).fields))
for (additional in additionalSerializers) {
typeNotations.addAll(additional.schemaForDocumentation.types)
}
Schema(typeNotations.toList())
}
/**
* Implement these two methods.
*/
protected abstract fun toProxy(obj: T): P
protected abstract fun fromProxy(proxy: P): T
override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput) {
val proxy = toProxy(obj)
data.withList {
for (property in proxySerializer.propertySerializers) {
property.writeProperty(proxy, this, output)
}
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): T {
@Suppress("UNCHECKED_CAST")
val proxy = proxySerializer.readObject(obj, schema, input) as P
return fromProxy(proxy)
}
}
/**
* A custom serializer where the on-wire representation is a string. For example, a [Currency] might be represented
* as a 3 character currency code, and converted to and from that string. By default, it is assumed that the
* [toString] method will generate the string representation and that there is a constructor that takes such a
* string as an argument to reconstruct.
*
* @param clazz The type to be marshalled
* @param withInheritance Whether subclasses of the class can also be marshalled.
* @param make A lambda for constructing an instance, that defaults to calling a constructor that expects a string.
* @param unmake A lambda that extracts the string value for an instance, that defaults to the [toString] method.
*/
abstract class ToString<T>(clazz: Class<T>, withInheritance: Boolean = false,
private val maker: (String) -> T = clazz.getConstructor(String::class.java).let { `constructor` -> { string -> `constructor`.newInstance(string) } },
private val unmaker: (T) -> String = { obj -> obj.toString() }) : Proxy<T, String>(clazz, String::class.java, /* Unused */ SerializerFactory(), withInheritance) {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = emptyList()
override val schemaForDocumentation = Schema(listOf(RestrictedType(nameForType(type), "", listOf(nameForType(type)), SerializerFactory.primitiveTypeName(String::class.java)!!, descriptor, emptyList())))
override fun toProxy(obj: T): String = unmaker(obj)
override fun fromProxy(proxy: String): T = maker(proxy)
override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput) {
val proxy = toProxy(obj)
data.putObject(proxy)
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): T {
val proxy = input.readObject(obj, schema, String::class.java) as String
return fromProxy(proxy)
}
}
}

View File

@ -0,0 +1,131 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.internal.getStackTraceAsString
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.amqp.DescribedType
import org.apache.qpid.proton.amqp.UnsignedByte
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.Type
import java.nio.ByteBuffer
import java.util.*
data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
/**
* Main entry point for deserializing an AMQP encoded object.
*
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
* instances and threads.
*/
class DeserializationInput(internal val serializerFactory: SerializerFactory = SerializerFactory()) {
// TODO: we're not supporting object refs yet
private val objectHistory: MutableList<Any> = ArrayList()
internal companion object {
val BYTES_NEEDED_TO_PEEK: Int = 23
fun peekSize(bytes: ByteArray): Int {
// There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor
val eighth = bytes[8].toInt()
check(eighth == 0x0) { "Expected to find a descriptor in the AMQP stream" }
// We should always have an Envelope, so the descriptor should be a 64-bit long (0x80)
val ninth = UnsignedByte.valueOf(bytes[9]).toInt()
check(ninth == 0x80) { "Expected to find a ulong in the AMQP stream" }
// Skip 8 bytes
val eighteenth = UnsignedByte.valueOf(bytes[18]).toInt()
check(eighteenth == 0xd0 || eighteenth == 0xc0) { "Expected to find a list8 or list32 in the AMQP stream" }
val size = if (eighteenth == 0xc0) {
// Next byte is size
UnsignedByte.valueOf(bytes[19]).toInt() - 3 // Minus three as PEEK_SIZE assumes 4 byte unsigned integer.
} else {
// Next 4 bytes is size
UnsignedByte.valueOf(bytes[19]).toInt().shl(24) + UnsignedByte.valueOf(bytes[20]).toInt().shl(16) + UnsignedByte.valueOf(bytes[21]).toInt().shl(8) + UnsignedByte.valueOf(bytes[22]).toInt()
}
return size + BYTES_NEEDED_TO_PEEK
}
}
@Throws(NotSerializableException::class)
inline fun <reified T : Any> deserialize(bytes: SerializedBytes<T>): T =
deserialize(bytes, T::class.java)
@Throws(NotSerializableException::class)
inline internal fun <reified T : Any> deserializeAndReturnEnvelope(bytes: SerializedBytes<T>): ObjectAndEnvelope<T> =
deserializeAndReturnEnvelope(bytes, T::class.java)
@Throws(NotSerializableException::class)
private fun getEnvelope(bytes: ByteSequence): Envelope {
// Check that the lead bytes match expected header
val headerSize = AmqpHeaderV1_0.size
if (bytes.take(headerSize) != AmqpHeaderV1_0) {
throw NotSerializableException("Serialization header does not match.")
}
val data = Data.Factory.create()
val size = data.decode(ByteBuffer.wrap(bytes.bytes, bytes.offset + headerSize, bytes.size - headerSize))
if (size.toInt() != bytes.size - headerSize) {
throw NotSerializableException("Unexpected size of data")
}
return Envelope.get(data)
}
@Throws(NotSerializableException::class)
private fun <R> des(generator: () -> R): R {
try {
return generator()
} catch(nse: NotSerializableException) {
throw nse
} catch(t: Throwable) {
throw NotSerializableException("Unexpected throwable: ${t.message} ${t.getStackTraceAsString()}")
} finally {
objectHistory.clear()
}
}
/**
* This is the main entry point for deserialization of AMQP payloads, and expects a byte sequence involving a header
* indicating what version of Corda serialization was used, followed by an [Envelope] which carries the object to
* be deserialized and a schema describing the types of the objects.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>): T {
return des {
val envelope = getEnvelope(bytes)
clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz))
}
}
@Throws(NotSerializableException::class)
internal fun <T : Any> deserializeAndReturnEnvelope(bytes: SerializedBytes<T>, clazz: Class<T>): ObjectAndEnvelope<T> {
return des {
val envelope = getEnvelope(bytes)
// Now pick out the obj and schema from the envelope.
ObjectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope)
}
}
internal fun readObjectOrNull(obj: Any?, schema: Schema, type: Type): Any? {
return if (obj == null) null else readObject(obj, schema, type)
}
internal fun readObject(obj: Any, schema: Schema, type: Type): Any {
if (obj is DescribedType) {
// Look up serializer in factory by descriptor
val serializer = serializerFactory.get(obj.descriptor, schema)
if (serializer.type != type && !serializer.type.isSubClassOf(type))
throw NotSerializableException("Described type with descriptor ${obj.descriptor} was " +
"expected to be of type $type but was ${serializer.type}")
return serializer.readObject(obj.described, schema, this)
} else if (obj is Binary) {
return obj.array
} else {
return obj
}
}
}

View File

@ -0,0 +1,18 @@
package net.corda.nodeapi.internal.serialization.amqp
import java.lang.reflect.GenericArrayType
import java.lang.reflect.Type
import java.util.*
/**
* Implementation of [GenericArrayType] that we can actually construct.
*/
class DeserializedGenericArrayType(private val componentType: Type) : GenericArrayType {
override fun getGenericComponentType(): Type = componentType
override fun getTypeName(): String = "${componentType.typeName}[]"
override fun toString(): String = typeName
override fun hashCode(): Int = Objects.hashCode(componentType)
override fun equals(other: Any?): Boolean {
return other is GenericArrayType && (componentType == other.genericComponentType)
}
}

View File

@ -0,0 +1,165 @@
package net.corda.nodeapi.internal.serialization.amqp
import com.google.common.primitives.Primitives
import java.io.NotSerializableException
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.lang.reflect.TypeVariable
import java.util.*
/**
* Implementation of [ParameterizedType] that we can actually construct, and a parser from the string representation
* of the JDK implementation which we use as the textual format in the AMQP schema.
*/
class DeserializedParameterizedType(private val rawType: Class<*>, private val params: Array<out Type>, private val ownerType: Type? = null) : ParameterizedType {
init {
if (params.isEmpty()) {
throw NotSerializableException("Must be at least one parameter type in a ParameterizedType")
}
if (params.size != rawType.typeParameters.size) {
throw NotSerializableException("Expected ${rawType.typeParameters.size} for ${rawType.name} but found ${params.size}")
}
// We do not check bounds. Both our use cases (Collection and Map) are not bounded.
if (rawType.typeParameters.any { boundedType(it) }) throw NotSerializableException("Bounded types in ParameterizedTypes not supported, but found a bound in $rawType")
}
private fun boundedType(type: TypeVariable<out Class<out Any>>): Boolean {
return !(type.bounds.size == 1 && type.bounds[0] == Object::class.java)
}
val isFullyWildcarded: Boolean = params.all { it == SerializerFactory.AnyType }
private val _typeName: String = makeTypeName()
private fun makeTypeName(): String {
return if (isFullyWildcarded) {
rawType.name
} else {
val paramsJoined = params.map { it.typeName }.joinToString(", ")
"${rawType.name}<$paramsJoined>"
}
}
companion object {
// Maximum depth/nesting of generics before we suspect some DoS attempt.
const val MAX_DEPTH: Int = 32
fun make(name: String, cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader): Type {
val paramTypes = ArrayList<Type>()
val pos = parseTypeList("$name>", paramTypes, cl)
if (pos <= name.length) {
throw NotSerializableException("Malformed string form of ParameterizedType. Unexpected '>' at character position $pos of $name.")
}
if (paramTypes.size != 1) {
throw NotSerializableException("Expected only one type, but got $paramTypes")
}
return paramTypes[0]
}
private fun parseTypeList(params: String, types: MutableList<Type>, cl: ClassLoader, depth: Int = 0): Int {
var pos = 0
var typeStart = 0
var needAType = true
var skippingWhitespace = false
while (pos < params.length) {
if (params[pos] == '<') {
val typeEnd = pos++
val paramTypes = ArrayList<Type>()
pos = parseTypeParams(params, pos, paramTypes, cl, depth + 1)
types += makeParameterizedType(params.substring(typeStart, typeEnd).trim(), paramTypes, cl)
typeStart = pos
needAType = false
} else if (params[pos] == ',') {
val typeEnd = pos++
val typeName = params.substring(typeStart, typeEnd).trim()
if (!typeName.isEmpty()) {
types += makeType(typeName, cl)
} else if (needAType) {
throw NotSerializableException("Expected a type, not ','")
}
typeStart = pos
needAType = true
} else if (params[pos] == '>') {
val typeEnd = pos++
val typeName = params.substring(typeStart, typeEnd).trim()
if (!typeName.isEmpty()) {
types += makeType(typeName, cl)
} else if (needAType) {
throw NotSerializableException("Expected a type, not '>'")
}
return pos
} else {
// Skip forwards, checking character types
if (pos == typeStart) {
skippingWhitespace = false
if (params[pos].isWhitespace()) {
typeStart = pos++
} else if (!needAType) {
throw NotSerializableException("Not expecting a type")
} else if (params[pos] == '*') {
pos++
} else if (!params[pos].isJavaIdentifierStart()) {
throw NotSerializableException("Invalid character at start of type: ${params[pos]}")
} else {
pos++
}
} else {
if (params[pos].isWhitespace()) {
pos++
skippingWhitespace = true
} else if (!skippingWhitespace && (params[pos] == '.' || params[pos].isJavaIdentifierPart())) {
pos++
} else {
throw NotSerializableException("Invalid character in middle of type: ${params[pos]}")
}
}
}
}
throw NotSerializableException("Missing close generics '>'")
}
private fun makeType(typeName: String, cl: ClassLoader): Type {
// Not generic
return if (typeName == "?") SerializerFactory.AnyType else {
Primitives.wrap(SerializerFactory.primitiveType(typeName) ?: Class.forName(typeName, false, cl))
}
}
private fun makeParameterizedType(rawTypeName: String, args: MutableList<Type>, cl: ClassLoader): Type {
return DeserializedParameterizedType(makeType(rawTypeName, cl) as Class<*>, args.toTypedArray(), null)
}
private fun parseTypeParams(params: String, startPos: Int, paramTypes: MutableList<Type>, cl: ClassLoader, depth: Int): Int {
if (depth == MAX_DEPTH) {
throw NotSerializableException("Maximum depth of nested generics reached: $depth")
}
return startPos + parseTypeList(params.substring(startPos), paramTypes, cl, depth)
}
}
override fun getRawType(): Type = rawType
override fun getOwnerType(): Type? = ownerType
override fun getActualTypeArguments(): Array<out Type> = params
override fun getTypeName(): String = _typeName
override fun toString(): String = _typeName
override fun hashCode(): Int {
return Arrays.hashCode(this.actualTypeArguments) xor Objects.hashCode(this.ownerType) xor Objects.hashCode(this.rawType)
}
override fun equals(other: Any?): Boolean {
if (other is ParameterizedType) {
if (this === other) {
return true
} else {
return this.ownerType == other.ownerType && this.rawType == other.rawType && Arrays.equals(this.actualTypeArguments, other.actualTypeArguments)
}
} else {
return false
}
}
}

View File

@ -0,0 +1,72 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.util.*
import kotlin.collections.Map
import kotlin.collections.iterator
import kotlin.collections.map
/**
* Serialization / deserialization of certain supported [Map] types.
*/
class MapSerializer(val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(declaredType.toString())
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
companion object {
private val supportedTypes: Map<Class<out Map<*, *>>, (Map<*, *>) -> Map<*, *>> = mapOf(
Map::class.java to { map -> Collections.unmodifiableMap(map) },
SortedMap::class.java to { map -> Collections.unmodifiableSortedMap(TreeMap(map)) },
NavigableMap::class.java to { map -> Collections.unmodifiableNavigableMap(TreeMap(map)) }
)
private fun findConcreteType(clazz: Class<*>): (Map<*, *>) -> Map<*, *> {
return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported map type $clazz.")
}
}
private val concreteBuilder: (Map<*, *>) -> Map<*, *> = findConcreteType(declaredType.rawType as Class<*>)
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "map", Descriptor(typeDescriptor, null), emptyList())
override fun writeClassInfo(output: SerializationOutput) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(declaredType.actualTypeArguments[0])
output.requireSerializer(declaredType.actualTypeArguments[1])
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
obj.javaClass.checkNotUnorderedHashMap()
// Write described
data.withDescribed(typeNotation.descriptor) {
// Write map
data.putMap()
data.enter()
for ((key, value) in obj as Map<*, *>) {
output.writeObjectOrNull(key, data, declaredType.actualTypeArguments[0])
output.writeObjectOrNull(value, data, declaredType.actualTypeArguments[1])
}
data.exit() // exit map
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
// TODO: General generics question. Do we need to validate that entries in Maps and Collections match the generic type? Is it a security hole?
val entries: Iterable<Pair<Any?, Any?>> = (obj as Map<*, *>).map { readEntry(schema, input, it) }
return concreteBuilder(entries.toMap())
}
private fun readEntry(schema: Schema, input: DeserializationInput, entry: Map.Entry<Any?, Any?>) =
input.readObjectOrNull(entry.key, schema, declaredType.actualTypeArguments[0]) to
input.readObjectOrNull(entry.value, schema, declaredType.actualTypeArguments[1])
}
internal fun Class<*>.checkNotUnorderedHashMap() {
if (HashMap::class.java.isAssignableFrom(this) && !LinkedHashMap::class.java.isAssignableFrom(this)) {
throw IllegalArgumentException("Map type $this is unstable under iteration. Suggested fix: use java.util.LinkedHashMap instead.")
}
}

View File

@ -0,0 +1,81 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.nameForType
import org.apache.qpid.proton.amqp.UnsignedInteger
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.Constructor
import java.lang.reflect.Type
import kotlin.reflect.jvm.javaConstructor
/**
* Responsible for serializing and deserializing a regular object instance via a series of properties (matched with a constructor).
*/
class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type get() = clazz
private val javaConstructor: Constructor<Any>?
internal val propertySerializers: Collection<PropertySerializer>
init {
val kotlinConstructor = constructorForDeserialization(clazz)
javaConstructor = kotlinConstructor?.javaConstructor
propertySerializers = propertiesForSerialization(kotlinConstructor, clazz, factory)
}
private val typeName = nameForType(clazz)
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
private val interfaces = interfacesForSerialization(clazz) // TODO maybe this proves too much and we need annotations to restrict.
internal val typeNotation: TypeNotation = CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor, null), generateFields())
override fun writeClassInfo(output: SerializationOutput) {
if (output.writeTypeNotations(typeNotation)) {
for (iface in interfaces) {
output.requireSerializer(iface)
}
for (property in propertySerializers) {
property.writeClassInfo(output)
}
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
// Write described
data.withDescribed(typeNotation.descriptor) {
// Write list
withList {
for (property in propertySerializers) {
property.writeProperty(obj, this, output)
}
}
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
if (obj is UnsignedInteger) {
// TODO: Object refs
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
} else if (obj is List<*>) {
if (obj.size > propertySerializers.size) throw NotSerializableException("Too many properties in described type $typeName")
val params = obj.zip(propertySerializers).map { it.second.readProperty(it.first, schema, input) }
return construct(params)
} else throw NotSerializableException("Body of described type is unexpected $obj")
}
private fun generateFields(): List<Field> {
return propertySerializers.map { Field(it.name, it.type, it.requires, it.default, null, it.mandatory, false) }
}
private fun generateProvides(): List<String> {
return interfaces.map { nameForType(it) }
}
fun construct(properties: List<Any?>): Any {
if (javaConstructor == null) {
throw NotSerializableException("Attempt to deserialize an interface: $clazz. Serialized form is invalid.")
}
return javaConstructor.newInstance(*properties.toTypedArray())
}
}

View File

@ -0,0 +1,129 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Method
import java.lang.reflect.Type
import kotlin.reflect.full.memberProperties
import kotlin.reflect.jvm.javaGetter
/**
* Base class for serialization of a property of an object.
*/
sealed class PropertySerializer(val name: String, val readMethod: Method, val resolvedType: Type) {
abstract fun writeClassInfo(output: SerializationOutput)
abstract fun writeProperty(obj: Any?, data: Data, output: SerializationOutput)
abstract fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any?
val type: String = generateType()
val requires: List<String> = generateRequires()
val default: String? = generateDefault()
val mandatory: Boolean = generateMandatory()
private val isInterface: Boolean get() = resolvedType.asClass()?.isInterface ?: false
private val isJVMPrimitive: Boolean get() = resolvedType.asClass()?.isPrimitive ?: false
private fun generateType(): String {
return if (isInterface || resolvedType == Any::class.java) "*" else SerializerFactory.nameForType(resolvedType)
}
private fun generateRequires(): List<String> {
return if (isInterface) listOf(SerializerFactory.nameForType(resolvedType)) else emptyList()
}
private fun generateDefault(): String? {
if (isJVMPrimitive) {
return when (resolvedType) {
java.lang.Boolean.TYPE -> "false"
java.lang.Character.TYPE -> "&#0"
else -> "0"
}
} else {
return null
}
}
private fun generateMandatory(): Boolean {
return isJVMPrimitive || !readMethod.returnsNullable()
}
private fun Method.returnsNullable(): Boolean {
val returnTypeString = this.declaringClass.kotlin.memberProperties.firstOrNull { it.javaGetter == this }?.returnType?.toString() ?: "?"
return returnTypeString.endsWith('?') || returnTypeString.endsWith('!')
}
companion object {
fun make(name: String, readMethod: Method, resolvedType: Type, factory: SerializerFactory): PropertySerializer {
if (SerializerFactory.isPrimitive(resolvedType)) {
return when(resolvedType) {
Char::class.java, Character::class.java -> AMQPCharPropertySerializer(name, readMethod)
else -> AMQPPrimitivePropertySerializer(name, readMethod, resolvedType)
}
} else {
return DescribedTypePropertySerializer(name, readMethod, resolvedType) { factory.get(null, resolvedType) }
}
}
}
/**
* A property serializer for a complex type (another object).
*/
class DescribedTypePropertySerializer(name: String, readMethod: Method, resolvedType: Type, private val lazyTypeSerializer: () -> AMQPSerializer<*>) : PropertySerializer(name, readMethod, resolvedType) {
// This is lazy so we don't get an infinite loop when a method returns an instance of the class.
private val typeSerializer: AMQPSerializer<*> by lazy { lazyTypeSerializer() }
override fun writeClassInfo(output: SerializationOutput) {
if (resolvedType != Any::class.java) {
typeSerializer.writeClassInfo(output)
}
}
override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? {
return input.readObjectOrNull(obj, schema, resolvedType)
}
override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) {
output.writeObjectOrNull(readMethod.invoke(obj), data, resolvedType)
}
}
/**
* A property serializer for most AMQP primitive type (Int, String, etc).
*/
class AMQPPrimitivePropertySerializer(name: String, readMethod: Method, resolvedType: Type) : PropertySerializer(name, readMethod, resolvedType) {
override fun writeClassInfo(output: SerializationOutput) {}
override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? {
return if (obj is Binary) obj.array else obj
}
override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) {
val value = readMethod.invoke(obj)
if (value is ByteArray) {
data.putObject(Binary(value))
} else {
data.putObject(value)
}
}
}
/**
* A property serializer for the AMQP char type, needed as a specialisation as the underlying
* value of the character is stored in numeric UTF-16 form and on deserialisation requires explicit
* casting back to a char otherwise it's treated as an Integer and a TypeMismatch occurs
*/
class AMQPCharPropertySerializer(name: String, readMethod: Method) :
PropertySerializer(name, readMethod, Character::class.java) {
override fun writeClassInfo(output: SerializationOutput) {}
override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? {
return if(obj == null) null else (obj as Short).toChar()
}
override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) {
val input = readMethod.invoke(obj)
if (input != null) data.putShort((input as Char).toShort()) else data.putNull()
}
}
}

View File

@ -0,0 +1,408 @@
package net.corda.nodeapi.internal.serialization.amqp
import com.google.common.hash.Hasher
import com.google.common.hash.Hashing
import net.corda.core.crypto.toBase64
import net.corda.core.utilities.OpaqueBytes
import org.apache.qpid.proton.amqp.DescribedType
import org.apache.qpid.proton.amqp.UnsignedLong
import org.apache.qpid.proton.codec.Data
import org.apache.qpid.proton.codec.DescribedTypeConstructor
import java.io.NotSerializableException
import java.lang.reflect.GenericArrayType
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.lang.reflect.TypeVariable
import java.util.*
import net.corda.nodeapi.internal.serialization.carpenter.Field as CarpenterField
import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema
// TODO: get an assigned number as per AMQP spec
val DESCRIPTOR_TOP_32BITS: Long = 0xc0da0000
val DESCRIPTOR_DOMAIN: String = "net.corda"
// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB
val AmqpHeaderV1_0: OpaqueBytes = OpaqueBytes("corda\u0001\u0000\u0000".toByteArray())
/**
* This class wraps all serialized data, so that the schema can be carried along with it. We will provide various internal utilities
* to decompose and recompose with/without schema etc so that e.g. we can store objects with a (relationally) normalised out schema to
* avoid excessive duplication.
*/
// TODO: make the schema parsing lazy since mostly schemas will have been seen before and we only need it if we don't recognise a type descriptor.
data class Envelope(val obj: Any?, val schema: Schema) : DescribedType {
companion object : DescribedTypeConstructor<Envelope> {
val DESCRIPTOR = UnsignedLong(1L or DESCRIPTOR_TOP_32BITS)
val DESCRIPTOR_OBJECT = Descriptor(null, DESCRIPTOR)
fun get(data: Data): Envelope {
val describedType = data.`object` as DescribedType
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
val list = describedType.described as List<*>
return newInstance(listOf(list[0], Schema.get(list[1]!!)))
}
override fun getTypeClass(): Class<*> = Envelope::class.java
override fun newInstance(described: Any?): Envelope {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
return Envelope(list[0], list[1] as Schema)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(obj, schema)
}
/**
* This and the classes below are OO representations of the AMQP XML schema described in the specification. Their
* [toString] representations generate the associated XML form.
*/
data class Schema(val types: List<TypeNotation>) : DescribedType {
companion object : DescribedTypeConstructor<Schema> {
val DESCRIPTOR = UnsignedLong(2L or DESCRIPTOR_TOP_32BITS)
fun get(obj: Any): Schema {
val describedType = obj as DescribedType
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
val list = describedType.described as List<*>
return newInstance(listOf((list[0] as List<*>).map { TypeNotation.get(it!!) }))
}
override fun getTypeClass(): Class<*> = Schema::class.java
override fun newInstance(described: Any?): Schema {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
@Suppress("UNCHECKED_CAST")
return Schema(list[0] as List<TypeNotation>)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(types)
override fun toString(): String = types.joinToString("\n")
}
data class Descriptor(val name: String?, val code: UnsignedLong? = null) : DescribedType {
companion object : DescribedTypeConstructor<Descriptor> {
val DESCRIPTOR = UnsignedLong(3L or DESCRIPTOR_TOP_32BITS)
fun get(obj: Any): Descriptor {
val describedType = obj as DescribedType
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
return newInstance(describedType.described)
}
override fun getTypeClass(): Class<*> = Descriptor::class.java
override fun newInstance(described: Any?): Descriptor {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
return Descriptor(list[0] as? String, list[1] as? UnsignedLong)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(name, code)
override fun toString(): String {
val sb = StringBuilder("<descriptor")
if (name != null) {
sb.append(" name=\"$name\"")
}
if (code != null) {
val code = String.format("0x%08x:0x%08x", code.toLong().shr(32), code.toLong().and(0xffff))
sb.append(" code=\"$code\"")
}
sb.append("/>")
return sb.toString()
}
}
data class Field(val name: String, val type: String, val requires: List<String>, val default: String?, val label: String?, val mandatory: Boolean, val multiple: Boolean) : DescribedType {
companion object : DescribedTypeConstructor<Field> {
val DESCRIPTOR = UnsignedLong(4L or DESCRIPTOR_TOP_32BITS)
fun get(obj: Any): Field {
val describedType = obj as DescribedType
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
return newInstance(describedType.described)
}
override fun getTypeClass(): Class<*> = Field::class.java
override fun newInstance(described: Any?): Field {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
@Suppress("UNCHECKED_CAST")
return Field(list[0] as String, list[1] as String, list[2] as List<String>, list[3] as? String, list[4] as? String, list[5] as Boolean, list[6] as Boolean)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(name, type, requires, default, label, mandatory, multiple)
override fun toString(): String {
val sb = StringBuilder("<field name=\"$name\" type=\"$type\" mandatory=\"$mandatory\" multiple=\"$multiple\"")
if (requires.isNotEmpty()) {
sb.append(" requires=\"")
sb.append(requires.joinToString(","))
sb.append("\"")
}
if (default != null) {
sb.append(" default=\"$default\"")
}
if (!label.isNullOrBlank()) {
sb.append(" label=\"$label\"")
}
sb.append("/>")
return sb.toString()
}
}
sealed class TypeNotation : DescribedType {
companion object {
fun get(obj: Any): TypeNotation {
val describedType = obj as DescribedType
if (describedType.descriptor == CompositeType.DESCRIPTOR) {
return CompositeType.get(describedType)
} else if (describedType.descriptor == RestrictedType.DESCRIPTOR) {
return RestrictedType.get(describedType)
} else {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
}
}
abstract val name: String
abstract val label: String?
abstract val provides: List<String>
abstract val descriptor: Descriptor
}
data class CompositeType(override val name: String, override val label: String?, override val provides: List<String>, override val descriptor: Descriptor, val fields: List<Field>) : TypeNotation() {
companion object : DescribedTypeConstructor<CompositeType> {
val DESCRIPTOR = UnsignedLong(5L or DESCRIPTOR_TOP_32BITS)
fun get(describedType: DescribedType): CompositeType {
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
val list = describedType.described as List<*>
return newInstance(listOf(list[0], list[1], list[2], Descriptor.get(list[3]!!), (list[4] as List<*>).map { Field.get(it!!) }))
}
override fun getTypeClass(): Class<*> = CompositeType::class.java
override fun newInstance(described: Any?): CompositeType {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
@Suppress("UNCHECKED_CAST")
return CompositeType(list[0] as String, list[1] as? String, list[2] as List<String>, list[3] as Descriptor, list[4] as List<Field>)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(name, label, provides, descriptor, fields)
override fun toString(): String {
val sb = StringBuilder("<type class=\"composite\" name=\"$name\"")
if (!label.isNullOrBlank()) {
sb.append(" label=\"$label\"")
}
if (provides.isNotEmpty()) {
sb.append(" provides=\"")
sb.append(provides.joinToString(","))
sb.append("\"")
}
sb.append(">\n")
sb.append(" $descriptor\n")
for (field in fields) {
sb.append(" $field\n")
}
sb.append("</type>")
return sb.toString()
}
}
data class RestrictedType(override val name: String, override val label: String?, override val provides: List<String>, val source: String, override val descriptor: Descriptor, val choices: List<Choice>) : TypeNotation() {
companion object : DescribedTypeConstructor<RestrictedType> {
val DESCRIPTOR = UnsignedLong(6L or DESCRIPTOR_TOP_32BITS)
fun get(describedType: DescribedType): RestrictedType {
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
val list = describedType.described as List<*>
return newInstance(listOf(list[0], list[1], list[2], list[3], Descriptor.get(list[4]!!), (list[5] as List<*>).map { Choice.get(it!!) }))
}
override fun getTypeClass(): Class<*> = RestrictedType::class.java
override fun newInstance(described: Any?): RestrictedType {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
@Suppress("UNCHECKED_CAST")
return RestrictedType(list[0] as String, list[1] as? String, list[2] as List<String>, list[3] as String, list[4] as Descriptor, list[5] as List<Choice>)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(name, label, provides, source, descriptor, choices)
override fun toString(): String {
val sb = StringBuilder("<type class=\"restricted\" name=\"$name\"")
if (!label.isNullOrBlank()) {
sb.append(" label=\"$label\"")
}
sb.append(" source=\"$source\"")
if (provides.isNotEmpty()) {
sb.append(" provides=\"")
sb.append(provides.joinToString(","))
sb.append("\"")
}
sb.append(">\n")
sb.append(" $descriptor\n")
sb.append("</type>")
return sb.toString()
}
}
data class Choice(val name: String, val value: String) : DescribedType {
companion object : DescribedTypeConstructor<Choice> {
val DESCRIPTOR = UnsignedLong(7L or DESCRIPTOR_TOP_32BITS)
fun get(obj: Any): Choice {
val describedType = obj as DescribedType
if (describedType.descriptor != DESCRIPTOR) {
throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.")
}
return newInstance(describedType.described)
}
override fun getTypeClass(): Class<*> = Choice::class.java
override fun newInstance(described: Any?): Choice {
val list = described as? List<*> ?: throw IllegalStateException("Was expecting a list")
return Choice(list[0] as String, list[1] as String)
}
}
override fun getDescriptor(): Any = DESCRIPTOR
override fun getDescribed(): Any = listOf(name, value)
override fun toString(): String {
return "<choice name=\"$name\" value=\"$value\"/>"
}
}
private val ARRAY_HASH: String = "Array = true"
private val ALREADY_SEEN_HASH: String = "Already seen = true"
private val NULLABLE_HASH: String = "Nullable = true"
private val NOT_NULLABLE_HASH: String = "Nullable = false"
private val ANY_TYPE_HASH: String = "Any type = true"
private val TYPE_VARIABLE_HASH: String = "Type variable = true"
/**
* The method generates a fingerprint for a given JVM [Type] that should be unique to the schema representation.
* Thus it only takes into account properties and types and only supports the same object graph subset as the overall
* serialization code.
*
* The idea being that even for two classes that share the same name but differ in a minor way, the fingerprint will be
* different.
*/
// TODO: write tests
internal fun fingerprintForType(type: Type, factory: SerializerFactory): String {
return fingerprintForType(type, null, HashSet(), Hashing.murmur3_128().newHasher(), factory).hash().asBytes().toBase64()
}
internal fun fingerprintForDescriptors(vararg typeDescriptors: String): String {
val hasher = Hashing.murmur3_128().newHasher()
for (typeDescriptor in typeDescriptors) {
hasher.putUnencodedChars(typeDescriptor)
}
return hasher.hash().asBytes().toBase64()
}
// This method concatentates various elements of the types recursively as unencoded strings into the hasher, effectively
// creating a unique string for a type which we then hash in the calling function above.
private fun fingerprintForType(type: Type, contextType: Type?, alreadySeen: MutableSet<Type>, hasher: Hasher, factory: SerializerFactory): Hasher {
return if (type in alreadySeen) {
hasher.putUnencodedChars(ALREADY_SEEN_HASH)
} else {
alreadySeen += type
try {
if (type is SerializerFactory.AnyType) {
hasher.putUnencodedChars(ANY_TYPE_HASH)
} else if (type is Class<*>) {
if (type.isArray) {
fingerprintForType(type.componentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH)
} else if (SerializerFactory.isPrimitive(type)) {
hasher.putUnencodedChars(type.name)
} else if (isCollectionOrMap(type)) {
hasher.putUnencodedChars(type.name)
} else {
// Need to check if a custom serializer is applicable
val customSerializer = factory.findCustomSerializer(type, type)
if (customSerializer == null) {
if (type.kotlin.objectInstance != null) {
// TODO: name collision is too likely for kotlin objects, we need to introduce some reference
// to the CorDapp but maybe reference to the JAR in the short term.
hasher.putUnencodedChars(type.name)
} else {
fingerprintForObject(type, contextType, alreadySeen, hasher, factory)
}
} else {
hasher.putUnencodedChars(customSerializer.typeDescriptor)
}
}
} else if (type is ParameterizedType) {
// Hash the rawType + params
val clazz = type.rawType as Class<*>
val startingHash = if (isCollectionOrMap(clazz)) {
hasher.putUnencodedChars(clazz.name)
} else {
fingerprintForObject(type, type, alreadySeen, hasher, factory)
}
// ... and concatentate the type data for each parameter type.
type.actualTypeArguments.fold(startingHash) { orig, paramType -> fingerprintForType(paramType, type, alreadySeen, orig, factory) }
} else if (type is GenericArrayType) {
// Hash the element type + some array hash
fingerprintForType(type.genericComponentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH)
} else if (type is TypeVariable<*>) {
// TODO: include bounds
hasher.putUnencodedChars(type.name).putUnencodedChars(TYPE_VARIABLE_HASH)
} else {
throw NotSerializableException("Don't know how to hash")
}
} catch(e: NotSerializableException) {
throw NotSerializableException("${e.message} -> $type")
}
}
}
private fun isCollectionOrMap(type: Class<*>) = Collection::class.java.isAssignableFrom(type) || Map::class.java.isAssignableFrom(type)
private fun fingerprintForObject(type: Type, contextType: Type?, alreadySeen: MutableSet<Type>, hasher: Hasher, factory: SerializerFactory): Hasher {
// Hash the class + properties + interfaces
val name = type.asClass()?.name ?: throw NotSerializableException("Expected only Class or ParameterizedType but found $type")
propertiesForSerialization(constructorForDeserialization(type), contextType ?: type, factory).fold(hasher.putUnencodedChars(name)) { orig, prop ->
fingerprintForType(prop.resolvedType, type, alreadySeen, orig, factory).putUnencodedChars(prop.name).putUnencodedChars(if (prop.mandatory) NOT_NULLABLE_HASH else NULLABLE_HASH)
}
interfacesForSerialization(type).map { fingerprintForType(it, type, alreadySeen, hasher, factory) }
return hasher
}

View File

@ -0,0 +1,202 @@
package net.corda.nodeapi.internal.serialization.amqp
import com.google.common.reflect.TypeToken
import org.apache.qpid.proton.codec.Data
import java.beans.Introspector
import java.io.NotSerializableException
import java.lang.reflect.*
import java.util.*
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KParameter
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.jvm.javaType
/**
* Annotation indicating a constructor to be used to reconstruct instances of a class during deserialization.
*/
@Target(AnnotationTarget.CONSTRUCTOR)
@Retention(AnnotationRetention.RUNTIME)
annotation class ConstructorForDeserialization
/**
* Code for finding the constructor we will use for deserialization.
*
* If there's only one constructor, it selects that. If there are two and one is the default, it selects the other.
* Otherwise it starts with the primary constructor in kotlin, if there is one, and then will override this with any that is
* annotated with [@CordaConstructor]. It will report an error if more than one constructor is annotated.
*/
internal fun constructorForDeserialization(type: Type): KFunction<Any>? {
val clazz: Class<*> = type.asClass()!!
if (isConcrete(clazz)) {
var preferredCandidate: KFunction<Any>? = clazz.kotlin.primaryConstructor
var annotatedCount = 0
val kotlinConstructors = clazz.kotlin.constructors
val hasDefault = kotlinConstructors.any { it.parameters.isEmpty() }
for (kotlinConstructor in kotlinConstructors) {
if (preferredCandidate == null && kotlinConstructors.size == 1 && !hasDefault) {
preferredCandidate = kotlinConstructor
} else if (preferredCandidate == null && kotlinConstructors.size == 2 && hasDefault && kotlinConstructor.parameters.isNotEmpty()) {
preferredCandidate = kotlinConstructor
} else if (kotlinConstructor.findAnnotation<ConstructorForDeserialization>() != null) {
if (annotatedCount++ > 0) {
throw NotSerializableException("More than one constructor for $clazz is annotated with @CordaConstructor.")
}
preferredCandidate = kotlinConstructor
}
}
return preferredCandidate ?: throw NotSerializableException("No constructor for deserialization found for $clazz.")
} else {
return null
}
}
/**
* Identifies the properties to be used during serialization by attempting to find those that match the parameters to the
* deserialization constructor, if the class is concrete. If it is abstract, or an interface, then use all the properties.
*
* Note, you will need any Java classes to be compiled with the `-parameters` option to ensure constructor parameters have
* names accessible via reflection.
*/
internal fun <T : Any> propertiesForSerialization(kotlinConstructor: KFunction<T>?, type: Type, factory: SerializerFactory): Collection<PropertySerializer> {
val clazz = type.asClass()!!
return if (kotlinConstructor != null) propertiesForSerializationFromConstructor(kotlinConstructor, type, factory) else propertiesForSerializationFromAbstract(clazz, type, factory)
}
private fun isConcrete(clazz: Class<*>): Boolean = !(clazz.isInterface || Modifier.isAbstract(clazz.modifiers))
private fun <T : Any> propertiesForSerializationFromConstructor(kotlinConstructor: KFunction<T>, type: Type, factory: SerializerFactory): Collection<PropertySerializer> {
val clazz = (kotlinConstructor.returnType.classifier as KClass<*>).javaObjectType
// Kotlin reflection doesn't work with Java getters the way you might expect, so we drop back to good ol' beans.
val properties = Introspector.getBeanInfo(clazz).propertyDescriptors.filter { it.name != "class" }.groupBy { it.name }.mapValues { it.value[0] }
val rc: MutableList<PropertySerializer> = ArrayList(kotlinConstructor.parameters.size)
for (param in kotlinConstructor.parameters) {
val name = param.name ?: throw NotSerializableException("Constructor parameter of $clazz has no name.")
val matchingProperty = properties[name] ?: throw NotSerializableException("No property matching constructor parameter named $name of $clazz." +
" If using Java, check that you have the -parameters option specified in the Java compiler.")
// Check that the method has a getter in java.
val getter = matchingProperty.readMethod ?: throw NotSerializableException("Property has no getter method for $name of $clazz." +
" If using Java and the parameter name looks anonymous, check that you have the -parameters option specified in the Java compiler.")
val returnType = resolveTypeVariables(getter.genericReturnType, type)
if (constructorParamTakesReturnTypeOfGetter(getter, param)) {
rc += PropertySerializer.make(name, getter, returnType, factory)
} else {
throw NotSerializableException("Property type $returnType for $name of $clazz differs from constructor parameter type ${param.type.javaType}")
}
}
return rc
}
private fun constructorParamTakesReturnTypeOfGetter(getter: Method, param: KParameter): Boolean = TypeToken.of(param.type.javaType).isSupertypeOf(getter.genericReturnType)
private fun propertiesForSerializationFromAbstract(clazz: Class<*>, type: Type, factory: SerializerFactory): Collection<PropertySerializer> {
// Kotlin reflection doesn't work with Java getters the way you might expect, so we drop back to good ol' beans.
val properties = Introspector.getBeanInfo(clazz).propertyDescriptors.filter { it.name != "class" }.sortedBy { it.name }
val rc: MutableList<PropertySerializer> = ArrayList(properties.size)
for (property in properties) {
// Check that the method has a getter in java.
val getter = property.readMethod ?: throw NotSerializableException("Property has no getter method for ${property.name} of $clazz.")
val returnType = resolveTypeVariables(getter.genericReturnType, type)
rc += PropertySerializer.make(property.name, getter, returnType, factory)
}
return rc
}
internal fun interfacesForSerialization(type: Type): List<Type> {
val interfaces = LinkedHashSet<Type>()
exploreType(type, interfaces)
return interfaces.toList()
}
private fun exploreType(type: Type?, interfaces: MutableSet<Type>) {
val clazz = type?.asClass()
if (clazz != null) {
if (clazz.isInterface) interfaces += type
for (newInterface in clazz.genericInterfaces) {
if (newInterface !in interfaces) {
exploreType(resolveTypeVariables(newInterface, type), interfaces)
}
}
val superClass = clazz.genericSuperclass ?: return
exploreType(resolveTypeVariables(superClass, type), interfaces)
}
}
/**
* Extension helper for writing described objects.
*/
fun Data.withDescribed(descriptor: Descriptor, block: Data.() -> Unit) {
// Write described
putDescribed()
enter()
// Write descriptor
putObject(descriptor.code ?: descriptor.name)
block()
exit() // exit described
}
/**
* Extension helper for writing lists.
*/
fun Data.withList(block: Data.() -> Unit) {
// Write list
putList()
enter()
block()
exit() // exit list
}
private fun resolveTypeVariables(actualType: Type, contextType: Type?): Type {
val resolvedType = if (contextType != null) TypeToken.of(contextType).resolveType(actualType).type else actualType
// TODO: surely we check it is concrete at this point with no TypeVariables
return if (resolvedType is TypeVariable<*>) {
val bounds = resolvedType.bounds
return if (bounds.isEmpty()) SerializerFactory.AnyType else if (bounds.size == 1) resolveTypeVariables(bounds[0], contextType) else throw NotSerializableException("Got bounded type $actualType but only support single bound.")
} else {
resolvedType
}
}
internal fun Type.asClass(): Class<*>? {
return if (this is Class<*>) {
this
} else if (this is ParameterizedType) {
this.rawType.asClass()
} else if (this is GenericArrayType) {
this.genericComponentType.asClass()?.arrayClass()
} else null
}
internal fun Type.asArray(): Type? {
return if (this is Class<*>) {
this.arrayClass()
} else if (this is ParameterizedType) {
DeserializedGenericArrayType(this)
} else null
}
internal fun Class<*>.arrayClass(): Class<*> = java.lang.reflect.Array.newInstance(this, 0).javaClass
internal fun Type.isArray(): Boolean = (this is Class<*> && this.isArray) || (this is GenericArrayType)
internal fun Type.componentType(): Type {
check(this.isArray()) { "$this is not an array type." }
return (this as? Class<*>)?.componentType ?: (this as GenericArrayType).genericComponentType
}
internal fun Class<*>.asParameterizedType(): ParameterizedType {
return DeserializedParameterizedType(this, this.typeParameters)
}
internal fun Type.asParameterizedType(): ParameterizedType {
return when (this) {
is Class<*> -> this.asParameterizedType()
is ParameterizedType -> this
else -> throw NotSerializableException("Don't know how to convert to ParameterizedType")
}
}
internal fun Type.isSubClassOf(type: Type): Boolean {
return TypeToken.of(this).isSubtypeOf(type)
}

View File

@ -0,0 +1,91 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.serialization.SerializedBytes
import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.Type
import java.nio.ByteBuffer
import java.util.*
import kotlin.collections.LinkedHashSet
/**
* Main entry point for serializing an object to AMQP.
*
* @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple
* instances and threads.
*/
open class SerializationOutput(internal val serializerFactory: SerializerFactory = SerializerFactory()) {
// TODO: we're not supporting object refs yet
private val objectHistory: MutableMap<Any, Int> = IdentityHashMap()
private val serializerHistory: MutableSet<AMQPSerializer<*>> = LinkedHashSet()
private val schemaHistory: MutableSet<TypeNotation> = LinkedHashSet()
/**
* Serialize the given object to AMQP, wrapped in our [Envelope] wrapper which carries an AMQP 1.0 schema, and prefixed
* with a header to indicate that this is serialized with AMQP and not [Kryo], and what version of the Corda implementation
* of AMQP serialization contructed the serialized form.
*/
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T): SerializedBytes<T> {
try {
val data = Data.Factory.create()
data.withDescribed(Envelope.DESCRIPTOR_OBJECT) {
withList {
// Our object
writeObject(obj, this)
// The schema
writeSchema(Schema(schemaHistory.toList()), this)
}
}
val bytes = ByteArray(data.encodedSize().toInt() + 8)
val buf = ByteBuffer.wrap(bytes)
buf.put(AmqpHeaderV1_0.bytes)
data.encode(buf)
return SerializedBytes(bytes)
} finally {
objectHistory.clear()
serializerHistory.clear()
schemaHistory.clear()
}
}
internal fun writeObject(obj: Any, data: Data) {
writeObject(obj, data, obj.javaClass)
}
open fun writeSchema(schema: Schema, data: Data) {
data.putObject(schema)
}
internal fun writeObjectOrNull(obj: Any?, data: Data, type: Type) {
if (obj == null) {
data.putNull()
} else {
writeObject(obj, data, if (type == SerializerFactory.AnyType) obj.javaClass else type)
}
}
internal fun writeObject(obj: Any, data: Data, type: Type) {
val serializer = serializerFactory.get(obj.javaClass, type)
if (serializer !in serializerHistory) {
serializerHistory.add(serializer)
serializer.writeClassInfo(this)
}
serializer.writeObject(obj, data, type, this)
}
open internal fun writeTypeNotations(vararg typeNotation: TypeNotation): Boolean {
return schemaHistory.addAll(typeNotation)
}
open internal fun requireSerializer(type: Type) {
if (type != SerializerFactory.AnyType && type != Object::class.java) {
val serializer = serializerFactory.get(null, type)
if (serializer !in serializerHistory) {
serializerHistory.add(serializer)
serializer.writeClassInfo(this)
}
}
}
}

View File

@ -0,0 +1,364 @@
package net.corda.nodeapi.internal.serialization.amqp
import com.google.common.primitives.Primitives
import com.google.common.reflect.TypeResolver
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.carpenter.CarpenterSchemas
import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter
import net.corda.nodeapi.internal.serialization.carpenter.MetaCarpenter
import net.corda.nodeapi.internal.serialization.carpenter.carpenterSchema
import org.apache.qpid.proton.amqp.*
import java.io.NotSerializableException
import java.lang.reflect.GenericArrayType
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.lang.reflect.WildcardType
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList
import javax.annotation.concurrent.ThreadSafe
/**
* Factory of serializers designed to be shared across threads and invocations.
*/
// TODO: enums
// TODO: object references - need better fingerprinting?
// TODO: class references? (e.g. cheat with repeated descriptors using a long encoding, like object ref proposal)
// TODO: Inner classes etc. Should we allow? Currently not considered.
// TODO: support for intern-ing of deserialized objects for some core types (e.g. PublicKey) for memory efficiency
// TODO: maybe support for caching of serialized form of some core types for performance
// TODO: profile for performance in general
// TODO: use guava caches etc so not unbounded
// TODO: do we need to support a transient annotation to exclude certain properties?
// TODO: incorporate the class carpenter for classes not on the classpath.
// TODO: apply class loader logic and an "app context" throughout this code.
// TODO: schema evolution solution when the fingerprints do not line up.
// TODO: allow definition of well known types that are left out of the schema.
// TODO: generally map Object to '*' all over the place in the schema and make sure use of '*' amd '?' is consistent and documented in generics.
// TODO: found a document that states textual descriptors are Symbols. Adjust schema class appropriately.
// TODO: document and alert to the fact that classes cannot default superclass/interface properties otherwise they are "erased" due to matching with constructor.
// TODO: type name prefixes for interfaces and abstract classes? Or use label?
// TODO: generic types should define restricted type alias with source of the wildcarded version, I think, if we're to generate classes from schema
// TODO: need to rethink matching of constructor to properties in relation to implementing interfaces and needing those properties etc.
// TODO: need to support super classes as well as interfaces with our current code base... what's involved? If we continue to ban, what is the impact?
@ThreadSafe
class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) {
private val serializersByType = ConcurrentHashMap<Type, AMQPSerializer<Any>>()
private val serializersByDescriptor = ConcurrentHashMap<Any, AMQPSerializer<Any>>()
private val customSerializers = CopyOnWriteArrayList<CustomSerializer<out Any>>()
private val classCarpenter = ClassCarpenter()
/**
* Look up, and manufacture if necessary, a serializer for the given type.
*
* @param actualClass Will be null if there isn't an actual object instance available (e.g. for
* restricted type processing).
*/
@Throws(NotSerializableException::class)
fun get(actualClass: Class<*>?, declaredType: Type): AMQPSerializer<Any> {
val declaredClass = declaredType.asClass() ?: throw NotSerializableException(
"Declared types of $declaredType are not supported.")
val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType
val serializer = if (Collection::class.java.isAssignableFrom(declaredClass)) {
serializersByType.computeIfAbsent(declaredType) {
CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(
declaredClass, arrayOf(AnyType), null), this)
}
} else if (Map::class.java.isAssignableFrom(declaredClass)) {
serializersByType.computeIfAbsent(declaredClass) {
makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(
declaredClass, arrayOf(AnyType, AnyType), null))
}
} else {
makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType)
}
serializersByDescriptor.putIfAbsent(serializer.typeDescriptor, serializer)
return serializer
}
/**
* Try and infer concrete types for any generics type variables for the actual class encountered, based on the declared
* type.
*/
// TODO: test GenericArrayType
private fun inferTypeVariables(actualClass: Class<*>?, declaredClass: Class<*>, declaredType: Type): Type? {
if (declaredType is ParameterizedType) {
return inferTypeVariables(actualClass, declaredClass, declaredType)
} else if (declaredType is Class<*>) {
// Nothing to infer, otherwise we'd have ParameterizedType
return actualClass
} else if (declaredType is GenericArrayType) {
val declaredComponent = declaredType.genericComponentType
return inferTypeVariables(actualClass?.componentType, declaredComponent.asClass()!!, declaredComponent)?.asArray()
} else return null
}
/**
* Try and infer concrete types for any generics type variables for the actual class encountered, based on the declared
* type, which must be a [ParameterizedType].
*/
private fun inferTypeVariables(actualClass: Class<*>?, declaredClass: Class<*>, declaredType: ParameterizedType): Type? {
if (actualClass == null || declaredClass == actualClass) {
return null
} else if (declaredClass.isAssignableFrom(actualClass)) {
return if (actualClass.typeParameters.isNotEmpty()) {
// The actual class can never have type variables resolved, due to the JVM's use of type erasure, so let's try and resolve them
// Search for declared type in the inheritance hierarchy and then see if that fills in all the variables
val implementationChain: List<Type>? = findPathToDeclared(actualClass, declaredType, mutableListOf<Type>())
if (implementationChain != null) {
val start = implementationChain.last()
val rest = implementationChain.dropLast(1).drop(1)
val resolver = rest.reversed().fold(TypeResolver().where(start, declaredType)) {
resolved, chainEntry ->
val newResolved = resolved.resolveType(chainEntry)
TypeResolver().where(chainEntry, newResolved)
}
// The end type is a special case as it is a Class, so we need to fake up a ParameterizedType for it to get the TypeResolver to do anything.
val endType = DeserializedParameterizedType(actualClass, actualClass.typeParameters)
val resolvedType = resolver.resolveType(endType)
resolvedType
} else throw NotSerializableException("No inheritance path between actual $actualClass and declared $declaredType.")
} else actualClass
} else throw NotSerializableException("Found object of type $actualClass in a property expecting $declaredType")
}
// Stop when reach declared type or return null if we don't find it.
private fun findPathToDeclared(startingType: Type, declaredType: Type, chain: MutableList<Type>): List<Type>? {
chain.add(startingType)
val startingClass = startingType.asClass()
if (startingClass == declaredType.asClass()) {
// We're done...
return chain
}
// Now explore potential options of superclass and all interfaces
val superClass = startingClass?.genericSuperclass
val superClassChain = if (superClass != null) {
val resolved = TypeResolver().where(startingClass.asParameterizedType(), startingType.asParameterizedType()).resolveType(superClass)
findPathToDeclared(resolved, declaredType, ArrayList(chain))
} else null
if (superClassChain != null) return superClassChain
for (iface in startingClass?.genericInterfaces ?: emptyArray()) {
val resolved = TypeResolver().where(startingClass!!.asParameterizedType(), startingType.asParameterizedType()).resolveType(iface)
return findPathToDeclared(resolved, declaredType, ArrayList(chain)) ?: continue
}
return null
}
/**
* Lookup and manufacture a serializer for the given AMQP type descriptor, assuming we also have the necessary types
* contained in the [Schema].
*/
@Throws(NotSerializableException::class)
fun get(typeDescriptor: Any, schema: Schema): AMQPSerializer<Any> {
return serializersByDescriptor[typeDescriptor] ?: {
processSchema(schema)
serializersByDescriptor[typeDescriptor] ?: throw NotSerializableException("Could not find type matching descriptor $typeDescriptor.")
}()
}
/**
* Register a custom serializer for any type that cannot be serialized or deserialized by the default serializer
* that expects to find getters and a constructor with a parameter for each property.
*/
fun register(customSerializer: CustomSerializer<out Any>) {
if (!serializersByDescriptor.containsKey(customSerializer.typeDescriptor)) {
customSerializers += customSerializer
serializersByDescriptor[customSerializer.typeDescriptor] = customSerializer
for (additional in customSerializer.additionalSerializers) {
register(additional)
}
}
}
private fun processSchema(schema: Schema, sentinal: Boolean = false) {
val carpenterSchemas = CarpenterSchemas.newInstance()
for (typeNotation in schema.types) {
try {
processSchemaEntry(typeNotation, classCarpenter.classloader)
} catch (e: ClassNotFoundException) {
if (sentinal || (typeNotation !is CompositeType)) throw e
typeNotation.carpenterSchema(
classLoaders = listOf(classCarpenter.classloader), carpenterSchemas = carpenterSchemas)
}
}
if (carpenterSchemas.isNotEmpty()) {
val mc = MetaCarpenter(carpenterSchemas, classCarpenter)
mc.build()
processSchema(schema, true)
}
}
private fun processSchemaEntry(typeNotation: TypeNotation,
cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) {
when (typeNotation) {
is CompositeType -> processCompositeType(typeNotation, cl) // java.lang.Class (whether a class or interface)
is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics
}
}
private fun processRestrictedType(typeNotation: RestrictedType) {
// TODO: class loader logic, and compare the schema.
val type = typeForName(typeNotation.name)
get(null, type)
}
private fun processCompositeType(typeNotation: CompositeType,
cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) {
// TODO: class loader logic, and compare the schema.
val type = typeForName(typeNotation.name, cl)
get(type.asClass() ?: throw NotSerializableException("Unable to build composite type for $type"), type)
}
private fun makeClassSerializer(clazz: Class<*>, type: Type, declaredType: Type): AMQPSerializer<Any> = serializersByType.computeIfAbsent(type) {
if (isPrimitive(clazz)) {
AMQPPrimitiveSerializer(clazz)
} else {
findCustomSerializer(clazz, declaredType) ?: run {
if (type.isArray()) {
whitelisted(type.componentType())
if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this)
else ArraySerializer.make(type, this)
} else if (clazz.kotlin.objectInstance != null) {
whitelisted(clazz)
SingletonSerializer(clazz, clazz.kotlin.objectInstance!!, this)
} else {
whitelisted(type)
ObjectSerializer(type, this)
}
}
}
}
internal fun findCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer<Any>? {
// e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is AbstractMap, only Map.
// Otherwise it needs to inject additional schema for a RestrictedType source of the super type. Could be done, but do we need it?
for (customSerializer in customSerializers) {
if (customSerializer.isSerializerFor(clazz)) {
val declaredSuperClass = declaredType.asClass()?.superclass
if (declaredSuperClass == null || !customSerializer.isSerializerFor(declaredSuperClass)) {
return customSerializer
} else {
// Make a subclass serializer for the subclass and return that...
@Suppress("UNCHECKED_CAST")
return CustomSerializer.SubClass<Any>(clazz, customSerializer as CustomSerializer<Any>)
}
}
}
return null
}
private fun whitelisted(type: Type) {
val clazz = type.asClass()!!
if (!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz)) {
throw NotSerializableException("Class $type is not on the whitelist or annotated with @CordaSerializable.")
}
}
// Recursively check the class, interfaces and superclasses for our annotation.
internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean {
return type.isAnnotationPresent(CordaSerializable::class.java) ||
type.interfaces.any { hasAnnotationInHierarchy(it) }
|| (type.superclass != null && hasAnnotationInHierarchy(type.superclass))
}
private fun makeMapSerializer(declaredType: ParameterizedType): AMQPSerializer<Any> {
val rawType = declaredType.rawType as Class<*>
rawType.checkNotUnorderedHashMap()
return MapSerializer(declaredType, this)
}
companion object {
fun isPrimitive(type: Type): Boolean = primitiveTypeName(type) != null
fun primitiveTypeName(type: Type): String? {
val clazz = type as? Class<*> ?: return null
return primitiveTypeNames[Primitives.unwrap(clazz)]
}
fun primitiveType(type: String): Class<*>? {
return namesOfPrimitiveTypes[type]
}
private val primitiveTypeNames: Map<Class<*>, String> = mapOf(
Character::class.java to "char",
Char::class.java to "char",
Boolean::class.java to "boolean",
Byte::class.java to "byte",
UnsignedByte::class.java to "ubyte",
Short::class.java to "short",
UnsignedShort::class.java to "ushort",
Int::class.java to "int",
UnsignedInteger::class.java to "uint",
Long::class.java to "long",
UnsignedLong::class.java to "ulong",
Float::class.java to "float",
Double::class.java to "double",
Decimal32::class.java to "decimal32",
Decimal64::class.java to "decimal62",
Decimal128::class.java to "decimal128",
Date::class.java to "timestamp",
UUID::class.java to "uuid",
ByteArray::class.java to "binary",
String::class.java to "string",
Symbol::class.java to "symbol")
private val namesOfPrimitiveTypes: Map<String, Class<*>> = primitiveTypeNames.map { it.value to it.key }.toMap()
fun nameForType(type: Type): String = when (type) {
is Class<*> -> {
primitiveTypeName(type) ?: if (type.isArray) {
"${nameForType(type.componentType)}${if (type.componentType.isPrimitive) "[p]" else "[]"}"
} else type.name
}
is ParameterizedType -> "${nameForType(type.rawType)}<${type.actualTypeArguments.joinToString { nameForType(it) }}>"
is GenericArrayType -> "${nameForType(type.genericComponentType)}[]"
else -> throw NotSerializableException("Unable to render type $type to a string.")
}
private fun typeForName(
name: String,
cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader): Type {
return if (name.endsWith("[]")) {
val elementType = typeForName(name.substring(0, name.lastIndex - 1))
if (elementType is ParameterizedType || elementType is GenericArrayType) {
DeserializedGenericArrayType(elementType)
} else if (elementType is Class<*>) {
java.lang.reflect.Array.newInstance(elementType, 0).javaClass
} else {
throw NotSerializableException("Not able to deserialize array type: $name")
}
} else if (name.endsWith("[p]")) {
// There is no need to handle the ByteArray case as that type is coercible automatically
// to the binary type and is thus handled by the main serializer and doesn't need a
// special case for a primitive array of bytes
when (name) {
"int[p]" -> IntArray::class.java
"char[p]" -> CharArray::class.java
"boolean[p]" -> BooleanArray::class.java
"float[p]" -> FloatArray::class.java
"double[p]" -> DoubleArray::class.java
"short[p]" -> ShortArray::class.java
"long[p]" -> LongArray::class.java
else -> throw NotSerializableException("Not able to deserialize array type: $name")
}
} else {
DeserializedParameterizedType.make(name, cl)
}
}
}
object AnyType : WildcardType {
override fun getUpperBounds(): Array<Type> = arrayOf(Object::class.java)
override fun getLowerBounds(): Array<Type> = emptyArray()
override fun toString(): String = "?"
}
}

View File

@ -0,0 +1,32 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Type
/**
* A custom serializer that transports nothing on the wire (except a boolean "false", since AMQP does not support
* absolutely nothing, or null as a described type) when we have a singleton within the node that we just
* want converting back to that singleton instance on the receiving JVM.
*/
class SingletonSerializer(override val type: Class<*>, val singleton: Any, factory: SerializerFactory) : AMQPSerializer<Any> {
override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}"
private val interfaces = interfacesForSerialization(type)
private fun generateProvides(): List<String> = interfaces.map { it.typeName }
internal val typeNotation: TypeNotation = RestrictedType(type.typeName, "Singleton", generateProvides(), "boolean", Descriptor(typeDescriptor, null), emptyList())
override fun writeClassInfo(output: SerializationOutput) {
output.writeTypeNotations(typeNotation)
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
data.withDescribed(typeNotation.descriptor) {
data.putBoolean(false)
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
return singleton
}
}

View File

@ -0,0 +1,11 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import java.math.BigDecimal
/**
* A serializer for [BigDecimal], utilising the string based helper. [BigDecimal] seems to have no import/export
* features that are precision independent other than via a string. The format of the string is discussed in the
* documentation for [BigDecimal.toString].
*/
object BigDecimalSerializer : CustomSerializer.ToString<BigDecimal>(BigDecimal::class.java)

View File

@ -0,0 +1,12 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import java.util.*
/**
* A custom serializer for the [Currency] class, utilizing the currency code string representation.
*/
object CurrencySerializer : CustomSerializer.ToString<Currency>(Currency::class.java,
withInheritance = false,
maker = { Currency.getInstance(it) },
unmaker = { it.currencyCode })

View File

@ -0,0 +1,18 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import java.time.Instant
/**
* A serializer for [Instant] that uses a proxy object to write out the seconds since the epoch and the nanos.
*/
class InstantSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<Instant, InstantSerializer.InstantProxy>(Instant::class.java, InstantProxy::class.java, factory) {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = emptyList()
override fun toProxy(obj: Instant): InstantProxy = InstantProxy(obj.epochSecond, obj.nano)
override fun fromProxy(proxy: InstantProxy): Instant = Instant.ofEpochSecond(proxy.epochSeconds, proxy.nanos.toLong())
data class InstantProxy(val epochSeconds: Long, val nanos: Int)
}

View File

@ -0,0 +1,26 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.core.crypto.Crypto
import net.corda.nodeapi.internal.serialization.amqp.*
import org.apache.qpid.proton.codec.Data
import java.lang.reflect.Type
import java.security.PublicKey
/**
* A serializer that writes out a public key in X.509 format.
*/
object PublicKeySerializer : CustomSerializer.Implements<PublicKey>(PublicKey::class.java) {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = emptyList()
override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList())))
override fun writeDescribedObject(obj: PublicKey, data: Data, type: Type, output: SerializationOutput) {
// TODO: Instead of encoding to the default X509 format, we could have a custom per key type (space-efficient) serialiser.
output.writeObject(obj.encoded, data, clazz)
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): PublicKey {
val bits = input.readObject(obj, schema, ByteArray::class.java) as ByteArray
return Crypto.decodePublicKey(bits)
}
}

View File

@ -0,0 +1,81 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.core.CordaRuntimeException
import net.corda.core.CordaThrowable
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import net.corda.nodeapi.internal.serialization.amqp.constructorForDeserialization
import net.corda.nodeapi.internal.serialization.amqp.propertiesForSerialization
import java.io.NotSerializableException
class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<Throwable, ThrowableSerializer.ThrowableProxy>(Throwable::class.java, ThrowableProxy::class.java, factory) {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = listOf(StackTraceElementSerializer(factory))
override fun toProxy(obj: Throwable): ThrowableProxy {
val extraProperties: MutableMap<String, Any?> = LinkedHashMap()
val message = if (obj is CordaThrowable) {
// Try and find a constructor
try {
val constructor = constructorForDeserialization(obj.javaClass)
val props = propertiesForSerialization(constructor, obj.javaClass, factory)
for (prop in props) {
extraProperties[prop.name] = prop.readMethod.invoke(obj)
}
} catch(e: NotSerializableException) {
}
obj.originalMessage
} else {
obj.message
}
return ThrowableProxy(obj.javaClass.name, message, obj.stackTrace, obj.cause, obj.suppressed, extraProperties)
}
override fun fromProxy(proxy: ThrowableProxy): Throwable {
try {
// TODO: This will need reworking when we have multiple class loaders
val clazz = Class.forName(proxy.exceptionClass, false, this.javaClass.classLoader)
// If it is CordaException or CordaRuntimeException, we can seek any constructor and then set the properties
// Otherwise we just make a CordaRuntimeException
if (CordaThrowable::class.java.isAssignableFrom(clazz) && Throwable::class.java.isAssignableFrom(clazz)) {
val constructor = constructorForDeserialization(clazz)!!
val throwable = constructor.callBy(constructor.parameters.map { it to proxy.additionalProperties[it.name] }.toMap())
(throwable as CordaThrowable).apply {
if (this.javaClass.name != proxy.exceptionClass) this.originalExceptionClassName = proxy.exceptionClass
this.setMessage(proxy.message)
this.setCause(proxy.cause)
this.addSuppressed(proxy.suppressed)
}
return (throwable as Throwable).apply {
this.stackTrace = proxy.stackTrace
}
}
} catch (e: Exception) {
// If attempts to rebuild the exact exception fail, we fall through and build a runtime exception.
}
// If the criteria are not met or we experience an exception constructing the exception, we fall back to our own unchecked exception.
return CordaRuntimeException(proxy.exceptionClass).apply {
this.setMessage(proxy.message)
this.setCause(proxy.cause)
this.stackTrace = proxy.stackTrace
this.addSuppressed(proxy.suppressed)
}
}
class ThrowableProxy(
val exceptionClass: String,
val message: String?,
val stackTrace: Array<StackTraceElement>,
val cause: Throwable?,
val suppressed: Array<Throwable>,
val additionalProperties: Map<String, Any?>)
}
class StackTraceElementSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<StackTraceElement, StackTraceElementSerializer.StackTraceElementProxy>(StackTraceElement::class.java, StackTraceElementProxy::class.java, factory) {
override val additionalSerializers: Iterable<CustomSerializer<Any>> = emptyList()
override fun toProxy(obj: StackTraceElement): StackTraceElementProxy = StackTraceElementProxy(obj.className, obj.methodName, obj.fileName, obj.lineNumber)
override fun fromProxy(proxy: StackTraceElementProxy): StackTraceElement = StackTraceElement(proxy.declaringClass, proxy.methodName, proxy.fileName, proxy.lineNumber)
data class StackTraceElementProxy(val declaringClass: String, val methodName: String, val fileName: String?, val lineNumber: Int)
}

View File

@ -0,0 +1,25 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.*
import org.apache.qpid.proton.codec.Data
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.x500.X500Name
import java.lang.reflect.Type
/**
* Custom serializer for X500 names that utilizes their ASN.1 encoding on the wire.
*/
object X500NameSerializer : CustomSerializer.Implements<X500Name>(X500Name::class.java) {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = emptyList()
override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList())))
override fun writeDescribedObject(obj: X500Name, data: Data, type: Type, output: SerializationOutput) {
output.writeObject(obj.encoded, data, clazz)
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): X500Name {
val binary = input.readObject(obj, schema, ByteArray::class.java) as ByteArray
return X500Name.getInstance(ASN1InputStream(binary).readObject())
}
}

View File

@ -0,0 +1,142 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.Field as AMQPField
import net.corda.nodeapi.internal.serialization.amqp.Schema as AMQPSchema
fun AMQPSchema.carpenterSchema(
loaders: List<ClassLoader> = listOf<ClassLoader>(ClassLoader.getSystemClassLoader()))
: CarpenterSchemas {
val rtn = CarpenterSchemas.newInstance()
types.filterIsInstance<CompositeType>().forEach {
it.carpenterSchema(classLoaders = loaders, carpenterSchemas = rtn)
}
return rtn
}
/**
* if we can load the class then we MUST know about all of it's composite elements
*/
private fun CompositeType.validatePropertyTypes(
classLoaders: List<ClassLoader> = listOf<ClassLoader>(ClassLoader.getSystemClassLoader())) {
fields.forEach {
if (!it.validateType(classLoaders)) throw UncarpentableException(name, it.name, it.type)
}
}
fun AMQPField.typeAsString() = if (type == "*") requires[0] else type
/**
* based upon this AMQP schema either
* a) add the corresponding carpenter schema to the [carpenterSchemas] param
* b) add the class to the dependency tree in [carpenterSchemas] if it cannot be instantiated
* at this time
*
* @param classLoaders list of classLoaders, defaulting toe the system class loader, that might
* be used to load objects
* @param carpenterSchemas structure that holds the dependency tree and list of classes that
* need constructing
* @param force by default a schema is not added to [carpenterSchemas] if it already exists
* on the class path. For testing purposes schema generation can be forced
*/
fun CompositeType.carpenterSchema(
classLoaders: List<ClassLoader> = listOf<ClassLoader>(ClassLoader.getSystemClassLoader()),
carpenterSchemas: CarpenterSchemas,
force: Boolean = false) {
if (classLoaders.exists(name)) {
validatePropertyTypes(classLoaders)
if (!force) return
}
val providesList = mutableListOf<Class<*>>()
var isInterface = false
var isCreatable = true
provides.forEach {
if (name == it) {
isInterface = true
return@forEach
}
try {
providesList.add(classLoaders.loadIfExists(it))
} catch (e: ClassNotFoundException) {
carpenterSchemas.addDepPair(this, name, it)
isCreatable = false
}
}
val m: MutableMap<String, Field> = mutableMapOf()
fields.forEach {
try {
m[it.name] = FieldFactory.newInstance(it.mandatory, it.name, it.getTypeAsClass(classLoaders))
} catch (e: ClassNotFoundException) {
carpenterSchemas.addDepPair(this, name, it.typeAsString())
isCreatable = false
}
}
if (isCreatable) {
carpenterSchemas.carpenterSchemas.add(CarpenterSchemaFactory.newInstance(
name = name,
fields = m,
interfaces = providesList,
isInterface = isInterface))
}
}
// map a pair of (typename, mandatory) to the corresponding class type
// where the mandatory AMQP flag maps to the types nullability
val typeStrToType: Map<Pair<String, Boolean>, Class<out Any?>> = mapOf(
Pair("int", true) to Int::class.javaPrimitiveType!!,
Pair("int", false) to Integer::class.javaObjectType,
Pair("short", true) to Short::class.javaPrimitiveType!!,
Pair("short", false) to Short::class.javaObjectType,
Pair("long", true) to Long::class.javaPrimitiveType!!,
Pair("long", false) to Long::class.javaObjectType,
Pair("char", true) to Char::class.javaPrimitiveType!!,
Pair("char", false) to java.lang.Character::class.java,
Pair("boolean", true) to Boolean::class.javaPrimitiveType!!,
Pair("boolean", false) to Boolean::class.javaObjectType,
Pair("double", true) to Double::class.javaPrimitiveType!!,
Pair("double", false) to Double::class.javaObjectType,
Pair("float", true) to Float::class.javaPrimitiveType!!,
Pair("float", false) to Float::class.javaObjectType,
Pair("byte", true) to Byte::class.javaPrimitiveType!!,
Pair("byte", false) to Byte::class.javaObjectType
)
fun AMQPField.getTypeAsClass(
classLoaders: List<ClassLoader> = listOf<ClassLoader>(ClassLoader.getSystemClassLoader())
) = typeStrToType[Pair(type, mandatory)] ?: when (type) {
"string" -> String::class.java
"*" -> classLoaders.loadIfExists(requires[0])
else -> classLoaders.loadIfExists(type)
}
fun AMQPField.validateType(
classLoaders: List<ClassLoader> = listOf<ClassLoader>(ClassLoader.getSystemClassLoader())
) = when (type) {
"byte", "int", "string", "short", "long", "char", "boolean", "double", "float" -> true
"*" -> classLoaders.exists(requires[0])
else -> classLoaders.exists(type)
}
private fun List<ClassLoader>.exists(clazz: String) = this.find {
try { it.loadClass(clazz); true } catch (e: ClassNotFoundException) { false }
} != null
private fun List<ClassLoader>.loadIfExists(clazz: String): Class<*> {
this.forEach {
try {
return it.loadClass(clazz)
} catch (e: ClassNotFoundException) {
return@forEach
}
}
throw ClassNotFoundException(clazz)
}

View File

@ -0,0 +1,326 @@
package net.corda.nodeapi.internal.serialization.carpenter
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.*
import java.lang.Character.isJavaIdentifierPart
import java.lang.Character.isJavaIdentifierStart
import java.util.*
/**
* Any object that implements this interface is expected to expose its own fields via the [get] method, exactly
* as if `this.class.getMethod("get" + name.capitalize()).invoke(this)` had been called. It is intended as a more
* convenient alternative to reflection.
*/
interface SimpleFieldAccess {
operator fun get(name: String): Any?
}
class CarpenterClassLoader : ClassLoader(Thread.currentThread().contextClassLoader) {
fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size)
}
/**
* A class carpenter generates JVM bytecodes for a class given a schema and then loads it into a sub-classloader.
* The generated classes have getters, a toString method and implement a simple property access interface. The
* resulting class can then be accessed via reflection APIs, or cast to one of the requested interfaces.
*
* Additional interfaces may be requested if they consist purely of get methods and the schema matches.
*
* # Discussion
*
* This class may initially appear pointless: why create a class at runtime that simply holds data and which
* you cannot compile against? The purpose is to enable the synthesis of data classes based on (AMQP) schemas
* when the app that originally defined them is not available on the classpath. Whilst the getters and setters
* are not usable directly, many existing reflection based frameworks like JSON/XML processors, Swing property
* editor sheets, Groovy and so on can work with the JavaBean ("POJO") format. Feeding these objects to such
* frameworks can often be useful. The generic property access interface is helpful if you want to write code
* that accesses these schemas but don't want to actually define/depend on the classes themselves.
*
* # Usage notes
*
* This class is not thread safe.
*
* The generated class has private final fields and getters for each field. The constructor has one parameter
* for each field. In this sense it is like a Kotlin data class.
*
* The generated class implements [SimpleFieldAccess]. The get method takes the name of the field, not the name
* of a getter i.e. use .get("someVar") not .get("getSomeVar") or in Kotlin you can use square brackets syntax.
*
* The generated class implements toString() using Google Guava to simplify formatting. Make sure it's on the
* classpath of the generated classes.
*
* Generated classes can refer to each other as long as they're defined in the right order. They can also
* inherit from each other. When inheritance is used the constructor requires parameters in order of superclasses
* first, child class last.
*
* You cannot create boxed primitive fields with this class: fields are always of primitive type.
*
* Nullability information is not emitted.
*
* Each [ClassCarpenter] defines its own classloader and thus, its own namespace. If you create multiple
* carpenters, you can load the same schema with the same name and get two different classes, whose objects
* will not be interoperable.
*
* Equals/hashCode methods are not yet supported.
*/
class ClassCarpenter {
// TODO: Generics.
// TODO: Sandbox the generated code when a security manager is in use.
// TODO: Generate equals/hashCode.
// TODO: Support annotations.
// TODO: isFoo getter patterns for booleans (this is what Kotlin generates)
val classloader = CarpenterClassLoader()
private val _loaded = HashMap<String, Class<*>>()
private val String.jvm: String get() = replace(".", "/")
/** Returns a snapshot of the currently loaded classes as a map of full class name (package names+dots) -> class object */
val loaded: Map<String, Class<*>> = HashMap(_loaded)
/**
* Generate bytecode for the given schema and load into the JVM. The returned class object can be used to
* construct instances of the generated class.
*
* @throws DuplicateNameException if the schema's name is already taken in this namespace (you can create a
* new ClassCarpenter if you're OK with ambiguous names)
*/
fun build(schema: Schema): Class<*> {
validateSchema(schema)
// Walk up the inheritance hierarchy and then start walking back down once we either hit the top, or
// find a class we haven't generated yet.
val hierarchy = ArrayList<Schema>()
hierarchy += schema
var cursor = schema.superclass
while (cursor != null && cursor.name !in _loaded) {
hierarchy += cursor
cursor = cursor.superclass
}
hierarchy.reversed().forEach {
when (it) {
is InterfaceSchema -> generateInterface(it)
is ClassSchema -> generateClass(it)
}
}
assert (schema.name in _loaded)
return _loaded[schema.name]!!
}
private fun generateInterface(interfaceSchema: Schema): Class<*> {
return generate(interfaceSchema) { cw, schema ->
val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray()
with(cw) {
visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces)
generateAbstractGetters(schema)
visitEnd()
}
}
}
private fun generateClass(classSchema: Schema): Class<*> {
return generate(classSchema) { cw, schema ->
val superName = schema.superclass?.jvmName ?: "java/lang/Object"
val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList()
if (SimpleFieldAccess::class.java !in schema.interfaces) interfaces.add(SimpleFieldAccess::class.java.name.jvm)
with(cw) {
visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray())
generateFields(schema)
generateConstructor(schema)
generateGetters(schema)
if (schema.superclass == null)
generateGetMethod() // From SimplePropertyAccess
generateToString(schema)
visitEnd()
}
}
}
private fun generate(schema: Schema, generator: (ClassWriter, Schema) -> Unit): Class<*> {
// Lazy: we could compute max locals/max stack ourselves, it'd be faster.
val cw = ClassWriter(ClassWriter.COMPUTE_FRAMES or ClassWriter.COMPUTE_MAXS)
generator(cw, schema)
val clazz = classloader.load(schema.name, cw.toByteArray())
_loaded[schema.name] = clazz
return clazz
}
private fun ClassWriter.generateFields(schema: Schema) {
schema.fields.forEach { it.value.generateField(this) }
}
private fun ClassWriter.generateToString(schema: Schema) {
val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper"
with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", null, null)) {
visitCode()
// com.google.common.base.MoreObjects.toStringHelper("TypeName")
visitLdcInsn(schema.name.split('.').last())
visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", "(Ljava/lang/String;)L$toStringHelper;", false)
// Call the add() methods.
for ((name, field) in schema.fieldsIncludingSuperclasses().entries) {
visitLdcInsn(name)
visitVarInsn(ALOAD, 0) // this
visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name])
visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;${field.type})L$toStringHelper;", false)
}
// call toString() on the builder and return.
visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()Ljava/lang/String;", false)
visitInsn(ARETURN)
visitMaxs(0, 0)
visitEnd()
}
}
private fun ClassWriter.generateGetMethod() {
val ourJvmName = ClassCarpenter::class.java.name.jvm
with(visitMethod(ACC_PUBLIC, "get", "(Ljava/lang/String;)Ljava/lang/Object;", null, null)) {
visitCode()
visitVarInsn(ALOAD, 0) // Load 'this'
visitVarInsn(ALOAD, 1) // Load the name argument
// Using this generic helper method is slow, as it relies on reflection. A faster way would be
// to use a tableswitch opcode, or just push back on the user and ask them to use actual reflection
// or MethodHandles (super fast reflection) to access the object instead.
visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false)
visitInsn(ARETURN)
visitMaxs(0, 0)
visitEnd()
}
}
private fun ClassWriter.generateGetters(schema: Schema) {
for ((name, type) in schema.fields) {
with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null)) {
type.addNullabilityAnnotation(this)
visitCode()
visitVarInsn(ALOAD, 0) // Load 'this'
visitFieldInsn(GETFIELD, schema.jvmName, name, type.descriptor)
when (type.field) {
java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE,
java.lang.Character.TYPE -> visitInsn(IRETURN)
java.lang.Long.TYPE -> visitInsn(LRETURN)
java.lang.Double.TYPE -> visitInsn(DRETURN)
java.lang.Float.TYPE -> visitInsn(FRETURN)
else -> visitInsn(ARETURN)
}
visitMaxs(0, 0)
visitEnd()
}
}
}
private fun ClassWriter.generateAbstractGetters(schema: Schema) {
for ((name, field) in schema.fields) {
val opcodes = ACC_ABSTRACT + ACC_PUBLIC
with(visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null)) {
// abstract method doesn't have any implementation so just end
visitEnd()
}
}
}
private fun ClassWriter.generateConstructor(schema: Schema) {
with(visitMethod(
ACC_PUBLIC,
"<init>",
"(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V",
null,
null))
{
var idx = 0
schema.fields.values.forEach { it.visitParameter(this, idx++) }
visitCode()
// Calculate the super call.
val superclassFields = schema.superclass?.fieldsIncludingSuperclasses() ?: emptyMap()
visitVarInsn(ALOAD, 0)
val sc = schema.superclass
if (sc == null) {
visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false)
} else {
var slot = 1
superclassFields.values.forEach { slot += load(slot, it) }
val superDesc = sc.descriptorsIncludingSuperclasses().values.joinToString("")
visitMethodInsn(INVOKESPECIAL, sc.name.jvm, "<init>", "($superDesc)V", false)
}
// Assign the fields from parameters.
var slot = 1 + superclassFields.size
for ((name, field) in schema.fields.entries) {
field.nullTest(this, slot)
visitVarInsn(ALOAD, 0) // Load 'this' onto the stack
slot += load(slot, field) // Load the contents of the parameter onto the stack.
visitFieldInsn(PUTFIELD, schema.jvmName, name, field.descriptor)
}
visitInsn(RETURN)
visitMaxs(0, 0)
visitEnd()
}
}
private fun MethodVisitor.load(slot: Int, type: Field): Int {
when (type.field) {
java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE,
java.lang.Character.TYPE -> visitVarInsn(ILOAD, slot)
java.lang.Long.TYPE -> visitVarInsn(LLOAD, slot)
java.lang.Double.TYPE -> visitVarInsn(DLOAD, slot)
java.lang.Float.TYPE -> visitVarInsn(FLOAD, slot)
else -> visitVarInsn(ALOAD, slot)
}
return when (type.field) {
java.lang.Long.TYPE, java.lang.Double.TYPE -> 2
else -> 1
}
}
private fun validateSchema(schema: Schema) {
if (schema.name in _loaded) throw DuplicateNameException()
fun isJavaName(n: String) = n.isNotBlank() && isJavaIdentifierStart(n.first()) && n.all(::isJavaIdentifierPart)
require(isJavaName(schema.name.split(".").last())) { "Not a valid Java name: ${schema.name}" }
schema.fields.keys.forEach { require(isJavaName(it)) { "Not a valid Java name: $it" } }
// Now check each interface we've been asked to implement, as the JVM will unfortunately only catch the
// fact that we didn't implement the interface we said we would at the moment the missing method is
// actually called, which is a bit too dynamic for my tastes.
val allFields = schema.fieldsIncludingSuperclasses()
for (itf in schema.interfaces) {
itf.methods.forEach {
val fieldNameFromItf = when {
it.name.startsWith("get") -> it.name.substring(3).decapitalize()
else -> throw InterfaceMismatchException(
"Requested interfaces must consist only of methods that start "
+ "with 'get': ${itf.name}.${it.name}")
}
// If we're trying to carpent a class that prior to serialisation / deserialisation
// was made by a carpenter then we can ignore this (it will implement a plain get
// method from SimpleFieldAccess).
if (fieldNameFromItf.isEmpty() && SimpleFieldAccess::class.java in schema.interfaces) return@forEach
if ((schema is ClassSchema) and (fieldNameFromItf !in allFields))
throw InterfaceMismatchException(
"Interface ${itf.name} requires a field named $fieldNameFromItf but that "
+ "isn't found in the schema or any superclass schemas")
}
}
}
companion object {
@JvmStatic @Suppress("UNUSED")
fun getField(obj: Any, name: String): Any? = obj.javaClass.getMethod("get" + name.capitalize()).invoke(obj)
}
}

View File

@ -0,0 +1,11 @@
package net.corda.nodeapi.internal.serialization.carpenter
class DuplicateNameException : RuntimeException (
"An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.")
class InterfaceMismatchException(msg: String) : RuntimeException(msg)
class NullablePrimitiveException(msg: String) : RuntimeException(msg)
class UncarpentableException (name: String, field: String, type: String) :
Exception ("Class $name is loadable yet contains field $field of unknown type $type")

View File

@ -0,0 +1,107 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.TypeNotation
/**
* Generated from an AMQP schema this class represents the classes unknown to the deserialiser and that thusly
* require carpenting up in bytecode form. This is a multi step process as carpenting one object may be depedent
* upon the creation of others, this information is tracked in the dependency tree represented by
* [dependencies] and [dependsOn]. Creatable classes are stored in [carpenterSchemas].
*
* The state of this class after initial generation is expected to mutate as classes are built by the carpenter
* enablaing the resolution of dependencies and thus new carpenter schemas added whilst those already
* carpented schemas are removed.
*
* @property carpenterSchemas The list of carpentable classes
* @property dependencies Maps a class to a list of classes that depend on it being built first
* @property dependsOn Maps a class to a list of classes it depends on being built before it
*
* Once a class is constructed we can quickly check for resolution by first looking at all of its dependents in the
* [dependencies] map. This will give us a list of classes that depended on that class being carpented. We can then
* in turn look up all of those classes in the [dependsOn] list, remove their dependency on the newly created class,
* and if that list is reduced to zero know we can now generate a [Schema] for them and carpent them up
*/
data class CarpenterSchemas (
val carpenterSchemas: MutableList<Schema>,
val dependencies: MutableMap<String, Pair<TypeNotation, MutableList<String>>>,
val dependsOn: MutableMap<String, MutableList<String>>) {
companion object CarpenterSchemaConstructor {
fun newInstance(): CarpenterSchemas {
return CarpenterSchemas(
mutableListOf<Schema>(),
mutableMapOf<String, Pair<TypeNotation, MutableList<String>>>(),
mutableMapOf<String, MutableList<String>>())
}
}
fun addDepPair(type: TypeNotation, dependant: String, dependee: String) {
dependsOn.computeIfAbsent(dependee, { mutableListOf<String>() }).add(dependant)
dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf<String>()) }).second.add(dependee)
}
val size
get() = carpenterSchemas.size
fun isEmpty() = carpenterSchemas.isEmpty()
fun isNotEmpty() = carpenterSchemas.isNotEmpty()
}
/**
* Take a dependency tree of [CarpenterSchemas] and reduce it to zero by carpenting those classes that
* require it. As classes are carpented check for depdency resolution, if now free generate a [Schema] for
* that class and add it to the list of classes ([CarpenterSchemas.carpenterSchemas]) that require
* carpenting
*
* @property cc a reference to the actual class carpenter we're using to constuct classes
* @property objects a list of carpented classes loaded into the carpenters class loader
*/
abstract class MetaCarpenterBase (val schemas : CarpenterSchemas, val cc : ClassCarpenter = ClassCarpenter()) {
val objects = mutableMapOf<String, Class<*>>()
fun step(newObject: Schema) {
objects[newObject.name] = cc.build (newObject)
// go over the list of everything that had a dependency on the newly
// carpented class existing and remove it from their dependency list, If that
// list is now empty we have no impediment to carpenting that class up
schemas.dependsOn.remove(newObject.name)?.forEach { dependent ->
assert (newObject.name in schemas.dependencies[dependent]!!.second)
schemas.dependencies[dependent]?.second?.remove(newObject.name)
// we're out of blockers so we can now create the type
if (schemas.dependencies[dependent]?.second?.isEmpty() ?: false) {
(schemas.dependencies.remove (dependent)?.first as CompositeType).carpenterSchema (
classLoaders = listOf<ClassLoader> (
ClassLoader.getSystemClassLoader(),
cc.classloader),
carpenterSchemas = schemas)
}
}
}
abstract fun build()
val classloader : ClassLoader
get() = cc.classloader
}
class MetaCarpenter(schemas : CarpenterSchemas,
cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) {
override fun build() {
while (schemas.carpenterSchemas.isNotEmpty()) {
val newObject = schemas.carpenterSchemas.removeAt(0)
step (newObject)
}
}
}
class TestMetaCarpenter(schemas : CarpenterSchemas,
cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) {
override fun build() {
if (schemas.carpenterSchemas.isEmpty()) return
step (schemas.carpenterSchemas.removeAt(0))
}
}

View File

@ -0,0 +1,148 @@
package net.corda.nodeapi.internal.serialization.carpenter
import jdk.internal.org.objectweb.asm.Opcodes.*
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Type
import java.util.*
/**
* A Schema represents a desired class.
*/
abstract class Schema(
val name: String,
fields: Map<String, Field>,
val superclass: Schema? = null,
val interfaces: List<Class<*>> = emptyList())
{
private fun Map<String, Field>.descriptors() =
LinkedHashMap(this.mapValues { it.value.descriptor })
/* Fix the order up front if the user didn't, inject the name into the field as it's
neater when iterating */
val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) })
fun fieldsIncludingSuperclasses(): Map<String, Field> =
(superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields)
fun descriptorsIncludingSuperclasses(): Map<String, String> =
(superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors()
val jvmName: String
get() = name.replace(".", "/")
}
class ClassSchema(
name: String,
fields: Map<String, Field>,
superclass: Schema? = null,
interfaces: List<Class<*>> = emptyList()
) : Schema(name, fields, superclass, interfaces)
class InterfaceSchema(
name: String,
fields: Map<String, Field>,
superclass: Schema? = null,
interfaces: List<Class<*>> = emptyList()
) : Schema(name, fields, superclass, interfaces)
object CarpenterSchemaFactory {
fun newInstance (
name: String,
fields: Map<String, Field>,
superclass: Schema? = null,
interfaces: List<Class<*>> = emptyList(),
isInterface: Boolean = false
) : Schema =
if (isInterface) InterfaceSchema (name, fields, superclass, interfaces)
else ClassSchema (name, fields, superclass, interfaces)
}
abstract class Field(val field: Class<out Any?>) {
companion object {
const val unsetName = "Unset"
}
var name: String = unsetName
abstract val nullabilityAnnotation: String
val descriptor: String
get() = Type.getDescriptor(this.field)
val type: String
get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;"
fun generateField(cw: ClassWriter) {
val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null)
fieldVisitor.visitAnnotation(nullabilityAnnotation, true).visitEnd()
fieldVisitor.visitEnd()
}
fun addNullabilityAnnotation(mv: MethodVisitor) {
mv.visitAnnotation(nullabilityAnnotation, true).visitEnd()
}
fun visitParameter(mv: MethodVisitor, idx: Int) {
with(mv) {
visitParameter(name, 0)
if (!field.isPrimitive) {
visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd()
}
}
}
abstract fun copy(name: String, field: Class<out Any?>): Field
abstract fun nullTest(mv: MethodVisitor, slot: Int)
}
class NonNullableField(field: Class<out Any?>) : Field(field) {
override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;"
constructor(name: String, field: Class<out Any?>) : this(field) {
this.name = name
}
override fun copy(name: String, field: Class<out Any?>) = NonNullableField(name, field)
override fun nullTest(mv: MethodVisitor, slot: Int) {
assert(name != unsetName)
if (!field.isPrimitive) {
with(mv) {
visitVarInsn(ALOAD, 0) // load this
visitVarInsn(ALOAD, slot) // load parameter
visitLdcInsn("param \"$name\" cannot be null")
visitMethodInsn(INVOKESTATIC,
"java/util/Objects",
"requireNonNull",
"(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false)
visitInsn(POP)
}
}
}
}
class NullableField(field: Class<out Any?>) : Field(field) {
override val nullabilityAnnotation = "Ljavax/annotation/Nullable;"
constructor(name: String, field: Class<out Any?>) : this(field) {
if (field.isPrimitive) {
throw NullablePrimitiveException (
"Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable")
}
this.name = name
}
override fun copy(name: String, field: Class<out Any?>) = NullableField(name, field)
override fun nullTest(mv: MethodVisitor, slot: Int) {
assert(name != unsetName)
}
}
object FieldFactory {
fun newInstance (mandatory: Boolean, name: String, field: Class<out Any?>) =
if (mandatory) NonNullableField (name, field) else NullableField (name, field)
}

View File

@ -1,2 +1,2 @@
# Register a ServiceLoader service extending from net.corda.core.node.CordaPluginRegistry
net.corda.nodeapi.serialization.DefaultWhitelist
net.corda.nodeapi.internal.serialization.DefaultWhitelist

View File

@ -0,0 +1,331 @@
package net.corda.nodeapi
import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
import net.corda.nodeapi.internal.serialization.WireTransactionSerializer
import net.corda.nodeapi.internal.serialization.withTokenContext
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.MEGA_CORP
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.node.MockAttachmentStorage
import org.apache.commons.io.IOUtils
import org.junit.Assert
import org.junit.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.net.URL
import java.net.URLClassLoader
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
interface DummyContractBackdoor {
fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder
fun inspectState(state: ContractState): Int
}
val ATTACHMENT_TEST_PROGRAM_ID = AttachmentClassLoaderTests.AttachmentDummyContract()
class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
companion object {
val ISOLATED_CONTRACTS_JAR_PATH: URL = AttachmentClassLoaderTests::class.java.getResource("isolated.jar")
private fun SerializationContext.withAttachmentStorage(attachmentStorage: AttachmentStorage): SerializationContext {
val serviceHub = mock<ServiceHub>()
whenever(serviceHub.attachments).thenReturn(attachmentStorage)
return this.withTokenContext(SerializeAsTokenContextImpl(serviceHub) {}).withProperty(WireTransactionSerializer.attachmentsClassLoaderEnabled, true)
}
}
class AttachmentDummyContract : Contract {
data class State(val magicNumber: Int = 0) : ContractState {
override val contract = ATTACHMENT_TEST_PROGRAM_ID
override val participants: List<AbstractParty>
get() = listOf()
}
interface Commands : CommandData {
class Create : TypeOnlyCommandData(), Commands
}
override fun verify(tx: LedgerTransaction) {
// Always accepts.
}
// The "empty contract"
override val legalContractReference: SecureHash = SecureHash.sha256("")
fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder {
val state = State(magicNumber)
return TransactionBuilder(notary).withItems(state, Command(Commands.Create(), owner.party.owningKey))
}
}
fun importJar(storage: AttachmentStorage) = ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) }
// These ClassLoaders work together to load 'AnotherDummyContract' in a disposable way, such that even though
// the class may be on the unit test class path (due to default IDE settings, etc), it won't be loaded into the
// regular app classloader but rather than ClassLoaderForTests. This helps keep our environment clean and
// ensures we have precise control over where it's loaded.
object FilteringClassLoader : ClassLoader() {
override fun loadClass(name: String, resolve: Boolean): Class<*>? {
if ("AnotherDummyContract" in name) {
return null
} else
return super.loadClass(name, resolve)
}
}
class ClassLoaderForTests : URLClassLoader(arrayOf(ISOLATED_CONTRACTS_JAR_PATH), FilteringClassLoader)
@Test
fun `dynamically load AnotherDummyContract from isolated contracts jar`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as Contract
assertEquals(SecureHash.sha256("https://anotherdummy.org"), contract.legalContractReference)
}
fun fakeAttachment(filepath: String, content: String): ByteArray {
val bs = ByteArrayOutputStream()
val js = JarOutputStream(bs)
js.putNextEntry(ZipEntry(filepath))
js.writer().apply { append(content); flush() }
js.closeEntry()
js.close()
return bs.toByteArray()
}
fun readAttachment(attachment: Attachment, filepath: String): ByteArray {
ByteArrayOutputStream().use {
attachment.extractFile(filepath, it)
return it.toByteArray()
}
}
@Test
fun `test MockAttachmentStorage open as jar`() {
val storage = MockAttachmentStorage()
val key = importJar(storage)
val attachment = storage.openAttachment(key)!!
val jar = attachment.openAsJAR()
assertNotNull(jar.nextEntry)
}
@Test
fun `test overlapping file exception`() {
val storage = MockAttachmentStorage()
val att0 = importJar(storage)
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file.txt", "some other data")))
assertFailsWith(AttachmentsClassLoader.OverlappingAttachments::class) {
AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! })
}
}
@Test
fun `basic`() {
val storage = MockAttachmentStorage()
val att0 = importJar(storage)
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! })
val txt = IOUtils.toString(cl.getResourceAsStream("file1.txt"), Charsets.UTF_8.name())
assertEquals("some data", txt)
}
@Test
fun `Check platform independent path handling in attachment jars`() {
val storage = MockAttachmentStorage()
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("/folder1/foldera/file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("\\folder1\\folderb\\file2.txt", "some other data")))
val data1a = readAttachment(storage.openAttachment(att1)!!, "/folder1/foldera/file1.txt")
Assert.assertArrayEquals("some data".toByteArray(), data1a)
val data1b = readAttachment(storage.openAttachment(att1)!!, "\\folder1\\foldera\\file1.txt")
Assert.assertArrayEquals("some data".toByteArray(), data1b)
val data2a = readAttachment(storage.openAttachment(att2)!!, "\\folder1\\folderb\\file2.txt")
Assert.assertArrayEquals("some other data".toByteArray(), data2a)
val data2b = readAttachment(storage.openAttachment(att2)!!, "/folder1/folderb/file2.txt")
Assert.assertArrayEquals("some other data".toByteArray(), data2b)
}
@Test
fun `loading class AnotherDummyContract`() {
val storage = MockAttachmentStorage()
val att0 = importJar(storage)
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)
val contract = contractClass.newInstance() as Contract
assertEquals(cl, contract.javaClass.classLoader)
assertEquals(SecureHash.sha256("https://anotherdummy.org"), contract.legalContractReference)
}
@Test
fun `verify that contract DummyContract is in classPath`() {
val contractClass = Class.forName("net.corda.nodeapi.AttachmentClassLoaderTests\$AttachmentDummyContract")
val contract = contractClass.newInstance() as Contract
assertNotNull(contract)
}
fun createContract2Cash(): Contract {
val cl = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)
return contractClass.newInstance() as Contract
}
@Test
fun `testing Kryo with ClassLoader (with top level class name)`() {
val contract = createContract2Cash()
val bytes = contract.serialize()
val storage = MockAttachmentStorage()
val att0 = importJar(storage)
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(contract.javaClass)
val state2 = bytes.deserialize(context = context)
assertTrue(state2.javaClass.classLoader is AttachmentsClassLoader)
assertNotNull(state2)
}
// top level wrapper
@CordaSerializable
class Data(val contract: Contract)
@Test
fun `testing Kryo with ClassLoader (without top level class name)`() {
val data = Data(createContract2Cash())
assertNotNull(data.contract)
val context2 = P2P_CONTEXT.withWhitelisted(data.contract.javaClass)
val bytes = data.serialize(context = context2)
val storage = MockAttachmentStorage()
val att0 = importJar(storage)
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl))
val state2 = bytes.deserialize(context = context)
assertEquals(cl, state2.contract.javaClass.classLoader)
assertNotNull(state2)
// We should be able to load same class from a different class loader and have them be distinct.
val cl2 = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context3 = P2P_CONTEXT.withClassLoader(cl2).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2))
val state3 = bytes.deserialize(context = context3)
assertEquals(cl2, state3.contract.javaClass.classLoader)
assertNotNull(state3)
}
@Test
fun `test serialization of WireTransaction with statically loaded contract`() {
val tx = ATTACHMENT_TEST_PROGRAM_ID.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val wireTransaction = tx.toWireTransaction()
val bytes = wireTransaction.serialize()
val copiedWireTransaction = bytes.deserialize()
assertEquals(1, copiedWireTransaction.outputs.size)
assertEquals(42, (copiedWireTransaction.outputs[0].data as AttachmentDummyContract.State).magicNumber)
}
@Test
fun `test serialization of WireTransaction with dynamically loaded contract`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage()
val context = P2P_CONTEXT.withWhitelisted(contract.javaClass)
.withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child))
.withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child))
.withAttachmentStorage(storage)
// todo - think about better way to push attachmentStorage down to serializer
val bytes = run {
val attachmentRef = importJar(storage)
tx.addAttachment(storage.openAttachment(attachmentRef)!!.id)
val wireTransaction = tx.toWireTransaction()
wireTransaction.serialize(context = context)
}
val copiedWireTransaction = bytes.deserialize(context = context)
assertEquals(1, copiedWireTransaction.outputs.size)
val contract2 = copiedWireTransaction.getOutput(0).contract as DummyContractBackdoor
assertEquals(42, contract2.inspectState(copiedWireTransaction.outputs[0].data))
}
@Test
fun `test deserialize of WireTransaction where contract cannot be found`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage()
// todo - think about better way to push attachmentStorage down to serializer
val attachmentRef = importJar(storage)
val bytes = run {
tx.addAttachment(storage.openAttachment(attachmentRef)!!.id)
val wireTransaction = tx.toWireTransaction()
wireTransaction.serialize(context = P2P_CONTEXT.withAttachmentStorage(storage))
}
// use empty attachmentStorage
val e = assertFailsWith(MissingAttachmentsException::class) {
bytes.deserialize(context = P2P_CONTEXT.withAttachmentStorage(MockAttachmentStorage()))
}
assertEquals(attachmentRef, e.ids.single())
}
}

View File

@ -4,7 +4,7 @@ import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory.empty
import com.typesafe.config.ConfigRenderOptions.defaults
import com.typesafe.config.ConfigValueFactory
import net.corda.core.div
import net.corda.core.internal.div
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.testing.getTestX509Name
import org.assertj.core.api.Assertions.assertThat

View File

@ -0,0 +1,254 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.MapReferenceResolver
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.AttachmentClassLoaderTests
import net.corda.testing.node.MockAttachmentStorage
import org.junit.Rule
import org.junit.Test
import org.junit.rules.ExpectedException
import java.lang.IllegalStateException
import java.sql.Connection
import java.util.*
@CordaSerializable
enum class Foo {
Bar {
override val value = 0
},
Stick {
override val value = 1
};
abstract val value: Int
}
@CordaSerializable
open class Element
open class SubElement : Element()
class SubSubElement : SubElement()
abstract class AbstractClass
interface Interface
@CordaSerializable
interface SerializableInterface
interface SerializableSubInterface : SerializableInterface
class NotSerializable
class SerializableViaInterface : SerializableInterface
open class SerializableViaSubInterface : SerializableSubInterface
class SerializableViaSuperSubInterface : SerializableViaSubInterface()
@CordaSerializable
class CustomSerializable : KryoSerializable {
override fun read(kryo: Kryo?, input: Input?) {
}
override fun write(kryo: Kryo?, output: Output?) {
}
}
@CordaSerializable
@DefaultSerializer(DefaultSerializableSerializer::class)
class DefaultSerializable
class DefaultSerializableSerializer : Serializer<DefaultSerializable>() {
override fun write(kryo: Kryo, output: Output, obj: DefaultSerializable) {
}
override fun read(kryo: Kryo, input: Input, type: Class<DefaultSerializable>): DefaultSerializable {
return DefaultSerializable()
}
}
class CordaClassResolverTests {
val factory: SerializationFactory = object : SerializationFactory {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
}
val emptyWhitelistContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P)
val allButBlacklistedContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P)
@Test
fun `Annotation on enum works for specialised entries`() {
// TODO: Remove this suppress when we upgrade to kotlin 1.1 or when JetBrain fixes the bug.
@Suppress("UNSUPPORTED_FEATURE")
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Foo.Bar::class.java)
}
@Test
fun `Annotation on array element works`() {
val values = arrayOf(Element())
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(values.javaClass)
}
@Test
fun `Annotation not needed on abstract class`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(AbstractClass::class.java)
}
@Test
fun `Annotation not needed on interface`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Interface::class.java)
}
@Test
fun `Calling register method on modified Kryo does not consult the whitelist`() {
val kryo = CordaKryo(CordaClassResolver(factory, emptyWhitelistContext))
kryo.register(NotSerializable::class.java)
}
@Test(expected = KryoException::class)
fun `Calling register method on unmodified Kryo does consult the whitelist`() {
val kryo = Kryo(CordaClassResolver(factory, emptyWhitelistContext), MapReferenceResolver())
kryo.register(NotSerializable::class.java)
}
@Test(expected = KryoException::class)
fun `Annotation is needed without whitelisting`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(NotSerializable::class.java)
}
@Test
fun `Annotation is not needed with whitelisting`() {
val resolver = CordaClassResolver(factory, emptyWhitelistContext.withWhitelisted(NotSerializable::class.java))
resolver.getRegistration(NotSerializable::class.java)
}
@Test
fun `Annotation not needed on Object`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Object::class.java)
}
@Test
fun `Annotation not needed on primitive`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Integer.TYPE)
}
@Test(expected = KryoException::class)
fun `Annotation does not work for custom serializable`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(CustomSerializable::class.java)
}
@Test(expected = KryoException::class)
fun `Annotation does not work in conjunction with Kryo annotation`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(DefaultSerializable::class.java)
}
private fun importJar(storage: AttachmentStorage) = AttachmentClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) }
@Test(expected = KryoException::class)
fun `Annotation does not work in conjunction with AttachmentClassLoader annotation`() {
val storage = MockAttachmentStorage()
val attachmentHash = importJar(storage)
val classLoader = AttachmentsClassLoader(arrayOf(attachmentHash).map { storage.openAttachment(it)!! })
val attachedClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, classLoader)
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(attachedClass)
}
@Test
fun `Annotation is inherited from interfaces`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaInterface::class.java)
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSubInterface::class.java)
}
@Test
fun `Annotation is inherited from superclass`() {
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubElement::class.java)
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubSubElement::class.java)
CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSuperSubInterface::class.java)
}
// Blacklist tests.
@get:Rule
val expectedEx = ExpectedException.none()!!
@Test
fun `Check blacklisted class`() {
expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("Class java.util.HashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// HashSet is blacklisted.
resolver.getRegistration(HashSet::class.java)
}
open class SubHashSet<E> : HashSet<E>()
@Test
fun `Check blacklisted subclass`() {
expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubHashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// SubHashSet extends the blacklisted HashSet.
resolver.getRegistration(SubHashSet::class.java)
}
class SubSubHashSet<E> : SubHashSet<E>()
@Test
fun `Check blacklisted subsubclass`() {
expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubSubHashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// SubSubHashSet extends SubHashSet, which extends the blacklisted HashSet.
resolver.getRegistration(SubSubHashSet::class.java)
}
class ConnectionImpl(val connection: Connection) : Connection by connection
@Test
fun `Check blacklisted interface impl`() {
expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$ConnectionImpl is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// ConnectionImpl implements blacklisted Connection.
resolver.getRegistration(ConnectionImpl::class.java)
}
interface SubConnection : Connection
class SubConnectionImpl(val subConnection: SubConnection) : SubConnection by subConnection
@Test
fun `Check blacklisted super-interface impl`() {
expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubConnectionImpl is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// SubConnectionImpl implements SubConnection, which extends the blacklisted Connection.
resolver.getRegistration(SubConnectionImpl::class.java)
}
@Test
fun `Check forcibly allowed`() {
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// LinkedHashSet is allowed for serialization.
resolver.getRegistration(LinkedHashSet::class.java)
}
@CordaSerializable
class CordaSerializableHashSet<E> : HashSet<E>()
@Test
fun `Check blacklist precedes CordaSerializable`() {
expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$CordaSerializableHashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// CordaSerializableHashSet is @CordaSerializable, but extends the blacklisted HashSet.
resolver.getRegistration(CordaSerializableHashSet::class.java)
}
}

View File

@ -0,0 +1,212 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.google.common.primitives.Ints
import net.corda.core.crypto.*
import net.corda.core.serialization.*
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.sequence
import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.testing.ALICE_PUBKEY
import net.corda.testing.TestDependencyInjectionBase
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Before
import org.junit.Test
import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.InputStream
import java.time.Instant
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class KryoTests : TestDependencyInjectionBase() {
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
@Before
fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.P2P)
}
@Test
fun ok() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday)
val bits = mike.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday))
}
@Test
fun nullables() {
val bob = Person("bob", null)
val bits = bob.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null))
}
@Test
fun `serialised form is stable when the same object instance is added to the deserialised object graph`() {
val noReferencesContext = context.withoutReferences()
val obj = Ints.toByteArray(0x01234567).sequence()
val originalList = arrayListOf(obj)
val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext)
originalList += obj
deserialisedList += obj
assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext))
}
@Test
fun `serialised form is stable when the same object instance occurs more than once, and using java serialisation`() {
val noReferencesContext = context.withoutReferences()
val instant = Instant.ofEpochMilli(123)
val instantCopy = Instant.ofEpochMilli(123)
assertThat(instant).isNotSameAs(instantCopy)
val listWithCopies = arrayListOf(instant, instantCopy)
val listWithSameInstances = arrayListOf(instant, instant)
assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext))
}
@Test
fun `cyclic object graph`() {
val cyclic = Cyclic(3)
val bits = cyclic.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic)
}
@Test
fun `deserialised key pair functions the same as serialised one`() {
val keyPair = generateKeyPair()
val bitsToSign: ByteArray = Ints.toByteArray(0x01234567)
val wrongBits: ByteArray = Ints.toByteArray(0x76543210)
val signature = keyPair.sign(bitsToSign)
signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
}
@Test
fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.serialize(factory, context)
val deserialised = serialised.deserialize(factory, context)
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() })
val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
assertEquals(-1, readRubbishStream.read())
}
@Test
fun `serialize - deserialize SignableData`() {
val testString = "Hello World"
val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.serialize(factory, context).bytes
val meta2 = serializedMetaData.deserialize<SignableData>(factory, context)
assertEquals(meta2, meta)
}
@Test
fun `serialize - deserialize Logger`() {
val storageContext: SerializationContext = context // TODO: make it storage context
val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext)
assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2)
}
@Test
fun `HashCheckingStream (de)serialize`() {
val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() })
val readRubbishStream: InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(factory, context).deserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
assertEquals(-1, readRubbishStream.read())
}
@CordaSerializable
private data class Person(val name: String, val birthday: Instant?)
@Suppress("unused")
@CordaSerializable
private class Cyclic(val value: Int) {
val thisInstance = this
override fun equals(other: Any?): Boolean = (this === other) || (other is Cyclic && this.value == other.value)
override fun hashCode(): Int = value.hashCode()
override fun toString(): String = "Cyclic($value)"
}
@CordaSerializable
private object TestSingleton
object SimpleSteps {
object ONE : ProgressTracker.Step("one")
object TWO : ProgressTracker.Step("two")
object THREE : ProgressTracker.Step("three")
object FOUR : ProgressTracker.Step("four")
fun tracker() = ProgressTracker(ONE, TWO, THREE, FOUR)
}
object ChildSteps {
object AYY : ProgressTracker.Step("ayy")
object BEE : ProgressTracker.Step("bee")
object SEA : ProgressTracker.Step("sea")
fun tracker() = ProgressTracker(AYY, BEE, SEA)
}
@Test
fun rxSubscriptionsAreNotSerialized() {
val pt: ProgressTracker = SimpleSteps.tracker()
val pt2: ProgressTracker = ChildSteps.tracker()
class Unserializable : KryoSerializable {
override fun write(kryo: Kryo?, output: Output?) = throw AssertionError("not called")
override fun read(kryo: Kryo?, input: Input?) = throw AssertionError("not called")
fun foo() {
println("bar")
}
}
pt.setChildProgressTracker(SimpleSteps.TWO, pt2)
class Tmp {
val unserializable = Unserializable()
init {
pt2.changes.subscribe { unserializable.foo() }
}
}
Tmp()
val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
val context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.P2P)
pt.serialize(factory, context)
}
}

View File

@ -0,0 +1,126 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.Output
import com.nhaarman.mockito_kotlin.mock
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.*
import net.corda.core.utilities.OpaqueBytes
import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.testing.TestDependencyInjectionBase
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Test
import java.io.ByteArrayOutputStream
class SerializationTokenTest : TestDependencyInjectionBase() {
lateinit var factory: SerializationFactory
lateinit var context: SerializationContext
@Before
fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.P2P)
}
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized
private class LargeTokenizable : SingletonSerializeAsToken() {
val bytes = OpaqueBytes(ByteArray(1024))
val numBytes: Int
get() = bytes.size
override fun hashCode() = bytes.size
override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size
}
private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContextImpl(toBeTokenized, factory, context, mock<ServiceHub>())
@Test
fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
private class UnitSerializeAsToken : SingletonSerializeAsToken()
@Test
fun `write and read singleton`() {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@Test(expected = UnsupportedOperationException::class)
fun `new token encountered after context init`() {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
tokenizableBefore.serialize(factory, testContext)
}
@Test(expected = UnsupportedOperationException::class)
fun `deserialize unregistered token`() {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).serialize(factory, testContext)
serializedBytes.deserialize(factory, testContext)
}
@Test(expected = KryoException::class)
fun `no context set`() {
val tokenizableBefore = UnitSerializeAsToken()
tokenizableBefore.serialize(factory, context)
}
@Test(expected = KryoException::class)
fun `deserialize non-token`() {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(factory, this.context)))
val stream = ByteArrayOutputStream()
Output(stream).use {
it.write(KryoHeaderV0_1.bytes)
kryo.writeClass(it, SingletonSerializeAsToken::class.java)
kryo.writeObject(it, emptyList<Any>())
}
val serializedBytes = SerializedBytes<Any>(stream.toByteArray())
serializedBytes.deserialize(factory, testContext)
}
private class WrongTypeSerializeAsToken : SerializeAsToken {
object UnitSerializationToken : SerializationToken {
override fun fromToken(context: SerializeAsTokenContext): Any = UnitSerializeAsToken()
}
override fun toToken(context: SerializeAsTokenContext): SerializationToken = UnitSerializationToken
}
@Test(expected = KryoException::class)
fun `token returns unexpected type`() {
val tokenizableBefore = WrongTypeSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
serializedBytes.deserialize(factory, testContext)
}
}

View File

@ -0,0 +1,13 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.apache.qpid.proton.codec.Data
class TestSerializationOutput(
private val verbose: Boolean,
serializerFactory: SerializerFactory = SerializerFactory()) : SerializationOutput(serializerFactory) {
override fun writeSchema(schema: Schema, data: Data) {
if (verbose) println(schema)
super.writeSchema(schema, data)
}
}

View File

@ -0,0 +1,46 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
import kotlin.test.assertTrue
class DeserializeAndReturnEnvelopeTests {
fun testName(): String = Thread.currentThread().stackTrace[2].methodName
@Suppress("NOTHING_TO_INLINE")
inline fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz"
@Test
fun oneType() {
data class A(val a: Int, val b: String)
val a = A(10, "20")
val factory = SerializerFactory()
fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is A)
assertEquals(1, obj.envelope.schema.types.size)
assertEquals(classTestName("A"), obj.envelope.schema.types.first().name)
}
@Test
fun twoTypes() {
data class A(val a: Int, val b: String)
data class B(val a: A, val b: Float)
val b = B(A(10, "20"), 30.0F)
val factory = SerializerFactory()
fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assertTrue(obj.obj is B)
assertEquals(2, obj.envelope.schema.types.size)
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("A") })
assertNotEquals(null, obj.envelope.schema.types.find { it.name == classTestName("B") })
}
}

View File

@ -0,0 +1,446 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.junit.Test
import kotlin.test.*
import net.corda.nodeapi.internal.serialization.carpenter.*
// These tests work by having the class carpenter build the classes we serialise and then deserialise. Because
// those classes don't exist within the system's Class Loader the deserialiser will be forced to carpent
// versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This
// replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath
class DeserializeNeedingCarpentrySimpleTypesTest {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
*/
private const val VERBOSE = false
}
val sf = SerializerFactory()
val sf2 = SerializerFactory()
@Test
fun singleInt() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"int" to NonNullableField(Integer::class.javaPrimitiveType!!)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(1))
val db = DeserializationInput(sf).deserialize(sb)
val db2 = DeserializationInput(sf2).deserialize(sb)
// despite being carpented, and thus not on the class path, we should've cached clazz
// inside the serialiser object and thus we should have created the same type
assertEquals (db::class.java, clazz)
assertNotEquals (db2::class.java, clazz)
assertNotEquals (db::class.java, db2::class.java)
assertEquals(1, db::class.java.getMethod("getInt").invoke(db))
assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2))
}
@Test
fun singleIntNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"int" to NullableField(Integer::class.java)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(1))
val db1 = DeserializationInput(sf).deserialize(sb)
val db2 = DeserializationInput(sf2).deserialize(sb)
assertEquals(clazz, db1::class.java)
assertNotEquals(clazz, db2::class.java)
assertEquals(1, db1::class.java.getMethod("getInt").invoke(db1))
assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2))
}
@Test
fun singleIntNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"int" to NullableField(Integer::class.java)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db1 = DeserializationInput(sf).deserialize(sb)
val db2 = DeserializationInput(sf2).deserialize(sb)
assertEquals(clazz, db1::class.java)
assertNotEquals(clazz, db2::class.java)
assertEquals(null, db1::class.java.getMethod("getInt").invoke(db1))
assertEquals(null, db2::class.java.getMethod("getInt").invoke(db2))
}
@Test
fun singleChar() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"char" to NonNullableField(Character::class.javaPrimitiveType!!)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance('a'))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals('a', db::class.java.getMethod("getChar").invoke(db))
}
@Test
fun singleCharNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"char" to NullableField(Character::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance('a'))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals('a', db::class.java.getMethod("getChar").invoke(db))
}
@Test
fun singleCharNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"char" to NullableField(java.lang.Character::class.java)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, db::class.java.getMethod("getChar").invoke(db))
}
@Test
fun singleLong() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"long" to NonNullableField(Long::class.javaPrimitiveType!!)
)))
val l : Long = 1
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(l, (db::class.java.getMethod("getLong").invoke(db)))
}
@Test
fun singleLongNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"long" to NullableField(Long::class.javaObjectType)
)))
val l : Long = 1
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(l, (db::class.java.getMethod("getLong").invoke(db)))
}
@Test
fun singleLongNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"long" to NullableField(Long::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, (db::class.java.getMethod("getLong").invoke(db)))
}
@Test
fun singleBoolean() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"boolean" to NonNullableField(Boolean::class.javaPrimitiveType!!)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(true))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(true, db::class.java.getMethod("getBoolean").invoke(db))
}
@Test
fun singleBooleanNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"boolean" to NullableField(Boolean::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(true))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(true, db::class.java.getMethod("getBoolean").invoke(db))
}
@Test
fun singleBooleanNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"boolean" to NullableField(Boolean::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, db::class.java.getMethod("getBoolean").invoke(db))
}
@Test
fun singleDouble() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"double" to NonNullableField(Double::class.javaPrimitiveType!!)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(10.0, db::class.java.getMethod("getDouble").invoke(db))
}
@Test
fun singleDoubleNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"double" to NullableField(Double::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(10.0, db::class.java.getMethod("getDouble").invoke(db))
}
@Test
fun singleDoubleNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"double" to NullableField(Double::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, db::class.java.getMethod("getDouble").invoke(db))
}
@Test
fun singleShort() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"short" to NonNullableField(Short::class.javaPrimitiveType!!)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(3.toShort()))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(3.toShort(), db::class.java.getMethod("getShort").invoke(db))
}
@Test
fun singleShortNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"short" to NullableField(Short::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(3.toShort()))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(3.toShort(), db::class.java.getMethod("getShort").invoke(db))
}
@Test
fun singleShortNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"short" to NullableField(Short::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, db::class.java.getMethod("getShort").invoke(db))
}
@Test
fun singleFloat() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"float" to NonNullableField(Float::class.javaPrimitiveType!!)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0F))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(10.0F, db::class.java.getMethod("getFloat").invoke(db))
}
@Test
fun singleFloatNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"float" to NullableField(Float::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0F))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(10.0F, db::class.java.getMethod("getFloat").invoke(db))
}
@Test
fun singleFloatNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"float" to NullableField(Float::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, db::class.java.getMethod("getFloat").invoke(db))
}
@Test
fun singleByte() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"byte" to NonNullableField(Byte::class.javaPrimitiveType!!)
)))
val b : Byte = 0b0101
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(b, db::class.java.getMethod("getByte").invoke(db))
assertEquals(0b0101, (db::class.java.getMethod("getByte").invoke(db) as Byte))
}
@Test
fun singleByteNullable() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"byte" to NullableField(Byte::class.javaObjectType)
)))
val b : Byte = 0b0101
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(b, db::class.java.getMethod("getByte").invoke(db))
assertEquals(0b0101, (db::class.java.getMethod("getByte").invoke(db) as Byte))
}
@Test
fun singleByteNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema("single", mapOf(
"byte" to NullableField(Byte::class.javaObjectType)
)))
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null))
val db = DeserializationInput(sf2).deserialize(sb)
assertNotEquals(clazz, db::class.java)
assertEquals(null, db::class.java.getMethod("getByte").invoke(db))
}
@Test
fun simpleTypeKnownInterface() {
val clazz = ClassCarpenter().build (ClassSchema(
"oneType", mapOf("name" to NonNullableField(String::class.java)),
interfaces = listOf (I::class.java)))
val testVal = "Some Person"
val classInstance = clazz.constructors[0].newInstance(testVal)
val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance)
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(clazz, deserializedObj::class.java)
assertTrue(deserializedObj is I)
assertEquals(testVal, (deserializedObj as I).getName())
}
@Test
fun manyTypes() {
val manyClass = ClassCarpenter().build (ClassSchema("many", mapOf(
"intA" to NonNullableField (Int::class.java),
"intB" to NullableField (Integer::class.java),
"intC" to NullableField (Integer::class.java),
"strA" to NonNullableField (String::class.java),
"strB" to NullableField (String::class.java),
"strC" to NullableField (String::class.java),
"charA" to NonNullableField (Char::class.java),
"charB" to NullableField (Character::class.javaObjectType),
"charC" to NullableField (Character::class.javaObjectType),
"shortA" to NonNullableField (Short::class.javaPrimitiveType!!),
"shortB" to NullableField (Short::class.javaObjectType),
"shortC" to NullableField (Short::class.javaObjectType),
"longA" to NonNullableField (Long::class.javaPrimitiveType!!),
"longB" to NullableField(Long::class.javaObjectType),
"longC" to NullableField(Long::class.javaObjectType),
"booleanA" to NonNullableField (Boolean::class.javaPrimitiveType!!),
"booleanB" to NullableField (Boolean::class.javaObjectType),
"booleanC" to NullableField (Boolean::class.javaObjectType),
"doubleA" to NonNullableField (Double::class.javaPrimitiveType!!),
"doubleB" to NullableField (Double::class.javaObjectType),
"doubleC" to NullableField (Double::class.javaObjectType),
"floatA" to NonNullableField (Float::class.javaPrimitiveType!!),
"floatB" to NullableField (Float::class.javaObjectType),
"floatC" to NullableField (Float::class.javaObjectType),
"byteA" to NonNullableField (Byte::class.javaPrimitiveType!!),
"byteB" to NullableField (Byte::class.javaObjectType),
"byteC" to NullableField (Byte::class.javaObjectType))))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(
manyClass.constructors.first().newInstance(
1, 2, null,
"a", "b", null,
'c', 'd', null,
3.toShort(), 4.toShort(), null,
100.toLong(), 200.toLong(), null,
true, false, null,
10.0, 20.0, null,
10.0F, 20.0F, null,
0b0101.toByte(), 0b1010.toByte(), null))
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(manyClass, deserializedObj::class.java)
assertEquals(1, deserializedObj::class.java.getMethod("getIntA").invoke(deserializedObj))
assertEquals(2, deserializedObj::class.java.getMethod("getIntB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getIntC").invoke(deserializedObj))
assertEquals("a", deserializedObj::class.java.getMethod("getStrA").invoke(deserializedObj))
assertEquals("b", deserializedObj::class.java.getMethod("getStrB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getStrC").invoke(deserializedObj))
assertEquals('c', deserializedObj::class.java.getMethod("getCharA").invoke(deserializedObj))
assertEquals('d', deserializedObj::class.java.getMethod("getCharB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getCharC").invoke(deserializedObj))
assertEquals(3.toShort(), deserializedObj::class.java.getMethod("getShortA").invoke(deserializedObj))
assertEquals(4.toShort(), deserializedObj::class.java.getMethod("getShortB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getShortC").invoke(deserializedObj))
assertEquals(100.toLong(), deserializedObj::class.java.getMethod("getLongA").invoke(deserializedObj))
assertEquals(200.toLong(), deserializedObj::class.java.getMethod("getLongB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getLongC").invoke(deserializedObj))
assertEquals(true, deserializedObj::class.java.getMethod("getBooleanA").invoke(deserializedObj))
assertEquals(false, deserializedObj::class.java.getMethod("getBooleanB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getBooleanC").invoke(deserializedObj))
assertEquals(10.0, deserializedObj::class.java.getMethod("getDoubleA").invoke(deserializedObj))
assertEquals(20.0, deserializedObj::class.java.getMethod("getDoubleB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getDoubleC").invoke(deserializedObj))
assertEquals(10.0F, deserializedObj::class.java.getMethod("getFloatA").invoke(deserializedObj))
assertEquals(20.0F, deserializedObj::class.java.getMethod("getFloatB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getFloatC").invoke(deserializedObj))
assertEquals(0b0101.toByte(), deserializedObj::class.java.getMethod("getByteA").invoke(deserializedObj))
assertEquals(0b1010.toByte(), deserializedObj::class.java.getMethod("getByteB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getByteC").invoke(deserializedObj))
}
}

View File

@ -0,0 +1,240 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.junit.Test
import kotlin.test.*
import net.corda.nodeapi.internal.serialization.carpenter.*
interface I {
fun getName() : String
}
/**
* These tests work by having the class carpenter build the classes we serialise and then deserialise them
* within the context of a second serialiser factory. The second factory is required as the first, having
* been used to serialise the class, will have cached a copy of the class and will thus bypass the need
* to pull it out of the class loader.
*
* However, those classes don't exist within the system's Class Loader and thus the deserialiser will be forced
* to carpent versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This
* replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath
*/
class DeserializeNeedingCarpentryTests {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
*/
private const val VERBOSE = false
}
val sf1 = SerializerFactory()
val sf2 = SerializerFactory()
@Test
fun verySimpleType() {
val testVal = 10
val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java))))
val classInstance = clazz.constructors[0].newInstance(testVal)
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance)
val deserializedObj1 = DeserializationInput(sf1).deserialize(serialisedBytes)
assertEquals(clazz, deserializedObj1::class.java)
assertEquals (testVal, deserializedObj1::class.java.getMethod("getA").invoke(deserializedObj1))
val deserializedObj2 = DeserializationInput(sf1).deserialize(serialisedBytes)
assertEquals(clazz, deserializedObj2::class.java)
assertEquals(deserializedObj1::class.java, deserializedObj2::class.java)
assertEquals (testVal, deserializedObj2::class.java.getMethod("getA").invoke(deserializedObj2))
val deserializedObj3 = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(clazz, deserializedObj3::class.java)
assertNotEquals(deserializedObj1::class.java, deserializedObj3::class.java)
assertNotEquals(deserializedObj2::class.java, deserializedObj3::class.java)
assertEquals (testVal, deserializedObj3::class.java.getMethod("getA").invoke(deserializedObj3))
val deserializedObj4 = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(clazz, deserializedObj4::class.java)
assertNotEquals(deserializedObj1::class.java, deserializedObj4::class.java)
assertNotEquals(deserializedObj2::class.java, deserializedObj4::class.java)
assertEquals(deserializedObj3::class.java, deserializedObj4::class.java)
assertEquals (testVal, deserializedObj4::class.java.getMethod("getA").invoke(deserializedObj4))
}
@Test
fun repeatedTypesAreRecognised() {
val testValA = 10
val testValB = 20
val testValC = 20
val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java))))
val concreteA = clazz.constructors[0].newInstance(testValA)
val concreteB = clazz.constructors[0].newInstance(testValB)
val concreteC = clazz.constructors[0].newInstance(testValC)
val deserialisedA = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(concreteA))
assertEquals (testValA, deserialisedA::class.java.getMethod("getA").invoke(deserialisedA))
val deserialisedB = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(concreteB))
assertEquals (testValB, deserialisedA::class.java.getMethod("getA").invoke(deserialisedB))
assertEquals (deserialisedA::class.java, deserialisedB::class.java)
// C is deseriliased with a different factory, meaning a different class carpenter, so the type
// won't already exist and it will be carpented a second time showing that when A and B are the
// same underlying class that we didn't create a second instance of the class with the
// second deserialisation
val lsf = SerializerFactory()
val deserialisedC = DeserializationInput(lsf).deserialize(TestSerializationOutput(VERBOSE, lsf).serialize(concreteC))
assertEquals (testValC, deserialisedC::class.java.getMethod("getA").invoke(deserialisedC))
assertNotEquals (deserialisedA::class.java, deserialisedC::class.java)
assertNotEquals (deserialisedB::class.java, deserialisedC::class.java)
}
@Test
fun simpleTypeKnownInterface() {
val clazz = ClassCarpenter().build (ClassSchema(
"oneType", mapOf("name" to NonNullableField(String::class.java)),
interfaces = listOf (I::class.java)))
val testVal = "Some Person"
val classInstance = clazz.constructors[0].newInstance(testVal)
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance)
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertTrue(deserializedObj is I)
assertEquals(testVal, (deserializedObj as I).getName())
}
@Test
fun arrayOfTypes() {
val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java))))
data class Outer (val a : Array<Any>)
val outer = Outer (arrayOf (
clazz.constructors[0].newInstance(1),
clazz.constructors[0].newInstance(2),
clazz.constructors[0].newInstance(3)))
val deserializedObj = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(outer))
assertNotEquals((deserializedObj.a[0])::class.java, (outer.a[0])::class.java)
assertNotEquals((deserializedObj.a[1])::class.java, (outer.a[1])::class.java)
assertNotEquals((deserializedObj.a[2])::class.java, (outer.a[2])::class.java)
assertEquals((deserializedObj.a[0])::class.java, (deserializedObj.a[1])::class.java)
assertEquals((deserializedObj.a[0])::class.java, (deserializedObj.a[2])::class.java)
assertEquals((deserializedObj.a[1])::class.java, (deserializedObj.a[2])::class.java)
assertEquals(
outer.a[0]::class.java.getMethod("getA").invoke(outer.a[0]),
deserializedObj.a[0]::class.java.getMethod("getA").invoke(deserializedObj.a[0]))
assertEquals(
outer.a[1]::class.java.getMethod("getA").invoke(outer.a[1]),
deserializedObj.a[1]::class.java.getMethod("getA").invoke(deserializedObj.a[1]))
assertEquals(
outer.a[2]::class.java.getMethod("getA").invoke(outer.a[2]),
deserializedObj.a[2]::class.java.getMethod("getA").invoke(deserializedObj.a[2]))
}
@Test
fun reusedClasses() {
val cc = ClassCarpenter()
val innerType = cc.build(ClassSchema("inner", mapOf("a" to NonNullableField(Int::class.java))))
val outerType = cc.build(ClassSchema("outer", mapOf("a" to NonNullableField(innerType))))
val inner = innerType.constructors[0].newInstance(1)
val outer = outerType.constructors[0].newInstance(innerType.constructors[0].newInstance(2))
val serializedI = TestSerializationOutput(VERBOSE, sf1).serialize(inner)
val deserialisedI = DeserializationInput(sf2).deserialize(serializedI)
val serialisedO = TestSerializationOutput(VERBOSE, sf1).serialize(outer)
val deserialisedO = DeserializationInput(sf2).deserialize(serialisedO)
// ensure out carpented version of inner is reused
assertEquals (deserialisedI::class.java,
(deserialisedO::class.java.getMethod("getA").invoke(deserialisedO))::class.java)
}
@Test
fun nestedTypes() {
val cc = ClassCarpenter()
val nestedClass = cc.build (ClassSchema("nestedType",
mapOf("name" to NonNullableField(String::class.java))))
val outerClass = cc.build (ClassSchema("outerType",
mapOf("inner" to NonNullableField(nestedClass))))
val classInstance = outerClass.constructors.first().newInstance(nestedClass.constructors.first().newInstance("name"))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance)
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
val inner = deserializedObj::class.java.getMethod("getInner").invoke(deserializedObj)
assertEquals("name", inner::class.java.getMethod("getName").invoke(inner))
}
@Test
fun repeatedNestedTypes() {
val cc = ClassCarpenter()
val nestedClass = cc.build (ClassSchema("nestedType",
mapOf("name" to NonNullableField(String::class.java))))
data class outer(val a: Any, val b: Any)
val classInstance = outer (
nestedClass.constructors.first().newInstance("foo"),
nestedClass.constructors.first().newInstance("bar"))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance)
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertEquals ("foo", deserializedObj.a::class.java.getMethod("getName").invoke(deserializedObj.a))
assertEquals ("bar", deserializedObj.b::class.java.getMethod("getName").invoke(deserializedObj.b))
}
@Test
fun listOfType() {
val unknownClass = ClassCarpenter().build (ClassSchema("unknownClass", mapOf(
"v1" to NonNullableField(Int::class.java),
"v2" to NonNullableField(Int::class.java))))
data class outer (val l : List<Any>)
val toSerialise = outer (listOf (
unknownClass.constructors.first().newInstance(1, 2),
unknownClass.constructors.first().newInstance(3, 4),
unknownClass.constructors.first().newInstance(5, 6),
unknownClass.constructors.first().newInstance(7, 8)))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(toSerialise)
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
var sentinel = 1
deserializedObj.l.forEach {
assertEquals(sentinel++, it::class.java.getMethod("getV1").invoke(it))
assertEquals(sentinel++, it::class.java.getMethod("getV2").invoke(it))
}
}
@Test
fun unknownInterface() {
val cc = ClassCarpenter()
val interfaceClass = cc.build (InterfaceSchema(
"gen.Interface",
mapOf("age" to NonNullableField (Int::class.java))))
val concreteClass = cc.build (ClassSchema ("gen.Class", mapOf(
"age" to NonNullableField (Int::class.java),
"name" to NonNullableField(String::class.java)),
interfaces = listOf (I::class.java, interfaceClass)))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(
concreteClass.constructors.first().newInstance(12, "timmy"))
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertTrue(deserializedObj is I)
assertEquals("timmy", (deserializedObj as I).getName())
assertEquals("timmy", deserializedObj::class.java.getMethod("getName").invoke(deserializedObj))
assertEquals(12, deserializedObj::class.java.getMethod("getAge").invoke(deserializedObj))
}
}

View File

@ -0,0 +1,484 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.junit.Test
import kotlin.test.assertEquals
// Prior to certain fixes being made within the [PropertySerializaer] classes these simple
// deserialization operations would've blown up with type mismatch errors where the deserlized
// char property of the class would've been treated as an Integer and given to the constructor
// as such
class DeserializeSimpleTypesTests {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
*/
private const val VERBOSE = false
}
val sf1 = SerializerFactory()
val sf2 = SerializerFactory()
@Test
fun testChar() {
data class C(val c: Char)
var deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('c')))
assertEquals('c', deserializedC.c)
// CYRILLIC CAPITAL LETTER YU (U+042E)
deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('Ю')))
assertEquals('Ю', deserializedC.c)
// ARABIC LETTER FEH WITH DOT BELOW (U+06A3)
deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('ڣ')))
assertEquals('ڣ', deserializedC.c)
// ARABIC LETTER DAD WITH DOT BELOW (U+06FB)
deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('ۻ')))
assertEquals('ۻ', deserializedC.c)
// BENGALI LETTER AA (U+0986)
deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('আ')))
assertEquals('আ', deserializedC.c)
}
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
@Test
fun testCharacter() {
data class C(val c: Character)
val c = C(Character('c'))
val serialisedC = SerializationOutput().serialize(c)
val deserializedC = DeserializationInput().deserialize(serialisedC)
assertEquals(c.c, deserializedC.c)
}
@Test
fun testNullCharacter() {
data class C(val c: Char?)
val c = C(null)
val serialisedC = SerializationOutput().serialize(c)
val deserializedC = DeserializationInput().deserialize(serialisedC)
assertEquals(c.c, deserializedC.c)
}
@Test
fun testArrayOfInt() {
class IA(val ia: Array<Int>)
val ia = IA(arrayOf(1, 2, 3))
assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString())
assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]")
val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia)
val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA)
assertEquals(ia.ia.size, deserializedIA.ia.size)
assertEquals(ia.ia[0], deserializedIA.ia[0])
assertEquals(ia.ia[1], deserializedIA.ia[1])
assertEquals(ia.ia[2], deserializedIA.ia[2])
}
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
@Test
fun testArrayOfInteger() {
class IA(val ia: Array<Integer>)
val ia = IA(arrayOf(Integer(1), Integer(2), Integer(3)))
assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString())
assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]")
val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia)
val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA)
assertEquals(ia.ia.size, deserializedIA.ia.size)
assertEquals(ia.ia[0], deserializedIA.ia[0])
assertEquals(ia.ia[1], deserializedIA.ia[1])
assertEquals(ia.ia[2], deserializedIA.ia[2])
}
/**
* Test unboxed primitives
*/
@Test
fun testIntArray() {
class IA(val ia: IntArray)
val v = IntArray(3)
v[0] = 1; v[1] = 2; v[2] = 3
val ia = IA(v)
assertEquals("class [I", ia.ia::class.java.toString())
assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[p]")
val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia)
val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA)
assertEquals(ia.ia.size, deserializedIA.ia.size)
assertEquals(ia.ia[0], deserializedIA.ia[0])
assertEquals(ia.ia[1], deserializedIA.ia[1])
assertEquals(ia.ia[2], deserializedIA.ia[2])
}
@Test
fun testArrayOfChars() {
class C(val c: Array<Char>)
val c = C(arrayOf('a', 'b', 'c'))
assertEquals("class [Ljava.lang.Character;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testCharArray() {
class C(val c: CharArray)
val v = CharArray(3)
v[0] = 'a'; v[1] = 'b'; v[2] = 'c'
val c = C(v)
assertEquals("class [C", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[p]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
var deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
// second test with more interesting characters
v[0] = 'ই'; v[1] = ' '; v[2] = 'ਔ'
val c2 = C(v)
deserializedC = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(c2))
assertEquals(c2.c.size, deserializedC.c.size)
assertEquals(c2.c[0], deserializedC.c[0])
assertEquals(c2.c[1], deserializedC.c[1])
assertEquals(c2.c[2], deserializedC.c[2])
}
@Test
fun testArrayOfBoolean() {
class C(val c: Array<Boolean>)
val c = C(arrayOf(true, false, false, true))
assertEquals("class [Ljava.lang.Boolean;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
assertEquals(c.c[3], deserializedC.c[3])
}
@Test
fun testBooleanArray() {
class C(val c: BooleanArray)
val c = C(BooleanArray(4))
c.c[0] = true; c.c[1] = false; c.c[2] = false; c.c[3] = true
assertEquals("class [Z", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[p]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
assertEquals(c.c[3], deserializedC.c[3])
}
@Test
fun testArrayOfByte() {
class C(val c: Array<Byte>)
val c = C(arrayOf(0b0001, 0b0101, 0b1111))
assertEquals("class [Ljava.lang.Byte;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "byte[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testByteArray() {
class C(val c: ByteArray)
val c = C(ByteArray(3))
c.c[0] = 0b0001; c.c[1] = 0b0101; c.c[2] = 0b1111
assertEquals("class [B", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "binary")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testArrayOfShort() {
class C(val c: Array<Short>)
val c = C(arrayOf(1, 2, 3))
assertEquals("class [Ljava.lang.Short;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testShortArray() {
class C(val c: ShortArray)
val c = C(ShortArray(3))
c.c[0] = 1; c.c[1] = 2; c.c[2] = 5
assertEquals("class [S", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[p]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testArrayOfLong() {
class C(val c: Array<Long>)
val c = C(arrayOf(2147483650, -2147483800, 10))
assertEquals("class [Ljava.lang.Long;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testLongArray() {
class C(val c: LongArray)
val c = C(LongArray(3))
c.c[0] = 2147483650; c.c[1] = -2147483800; c.c[2] = 10
assertEquals("class [J", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[p]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testArrayOfFloat() {
class C(val c: Array<Float>)
val c = C(arrayOf(10F, 100.023232F, -1455.433400F))
assertEquals("class [Ljava.lang.Float;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testFloatArray() {
class C(val c: FloatArray)
val c = C(FloatArray(3))
c.c[0] = 10F; c.c[1] = 100.023232F; c.c[2] = -1455.433400F
assertEquals("class [F", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[p]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testArrayOfDouble() {
class C(val c: Array<Double>)
val c = C(arrayOf(10.0, 100.2, -1455.2))
assertEquals("class [Ljava.lang.Double;", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun testDoubleArray() {
class C(val c: DoubleArray)
val c = C(DoubleArray(3))
c.c[0] = 10.0; c.c[1] = 100.2; c.c[2] = -1455.2
assertEquals("class [D", c.c::class.java.toString())
assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[p]")
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0], deserializedC.c[0])
assertEquals(c.c[1], deserializedC.c[1])
assertEquals(c.c[2], deserializedC.c[2])
}
@Test
fun arrayOfArrayOfInt() {
class C(val c: Array<Array<Int>>)
val c = C (arrayOf (arrayOf(1,2,3), arrayOf(4,5,6)))
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0].size, deserializedC.c[0].size)
assertEquals(c.c[0][0], deserializedC.c[0][0])
assertEquals(c.c[0][1], deserializedC.c[0][1])
assertEquals(c.c[0][2], deserializedC.c[0][2])
assertEquals(c.c[1].size, deserializedC.c[1].size)
assertEquals(c.c[1][0], deserializedC.c[1][0])
assertEquals(c.c[1][1], deserializedC.c[1][1])
assertEquals(c.c[1][2], deserializedC.c[1][2])
}
@Test
fun arrayOfIntArray() {
class C(val c: Array<IntArray>)
val c = C (arrayOf (IntArray(3), IntArray(3)))
c.c[0][0] = 1; c.c[0][1] = 2; c.c[0][2] = 3
c.c[1][0] = 4; c.c[1][1] = 5; c.c[1][2] = 6
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
assertEquals(c.c.size, deserializedC.c.size)
assertEquals(c.c[0].size, deserializedC.c[0].size)
assertEquals(c.c[0][0], deserializedC.c[0][0])
assertEquals(c.c[0][1], deserializedC.c[0][1])
assertEquals(c.c[0][2], deserializedC.c[0][2])
assertEquals(c.c[1].size, deserializedC.c[1].size)
assertEquals(c.c[1][0], deserializedC.c[1][0])
assertEquals(c.c[1][1], deserializedC.c[1][1])
assertEquals(c.c[1][2], deserializedC.c[1][2])
}
@Test
fun arrayOfArrayOfIntArray() {
class C(val c: Array<Array<IntArray>>)
val c = C(arrayOf(arrayOf(IntArray(3), IntArray(3), IntArray(3)),
arrayOf(IntArray(3), IntArray(3), IntArray(3)),
arrayOf(IntArray(3), IntArray(3), IntArray(3))))
for (i in 0..2) { for (j in 0..2) { for (k in 0..2) { c.c[i][j][k] = i + j + k } } }
val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c)
val deserializedC = DeserializationInput(sf1).deserialize(serialisedC)
for (i in 0..2) { for (j in 0..2) { for (k in 0..2) {
assertEquals(c.c[i][j][k], deserializedC.c[i][j][k])
}}}
}
@Test
fun nestedRepeatedTypes() {
class A(val a : A?, val b: Int)
var a = A(A(A(A(A(null, 1), 2), 3), 4), 5)
val sa = TestSerializationOutput(VERBOSE, sf1).serialize(a)
val da1 = DeserializationInput(sf1).deserialize(sa)
val da2 = DeserializationInput(sf2).deserialize(sa)
assertEquals(5, da1.b)
assertEquals(4, da1.a?.b)
assertEquals(3, da1.a?.a?.b)
assertEquals(2, da1.a?.a?.a?.b)
assertEquals(1, da1.a?.a?.a?.a?.b)
assertEquals(5, da2.b)
assertEquals(4, da2.a?.b)
assertEquals(3, da2.a?.a?.b)
assertEquals(2, da2.a?.a?.a?.b)
assertEquals(1, da2.a?.a?.a?.a?.b)
}
}

View File

@ -0,0 +1,100 @@
package net.corda.nodeapi.internal.serialization.amqp
import org.junit.Test
import java.io.NotSerializableException
import kotlin.test.assertEquals
class DeserializedParameterizedTypeTests {
private fun normalise(string: String): String {
return string.replace(" ", "")
}
private fun verify(typeName: String) {
val type = DeserializedParameterizedType.make(typeName)
assertEquals(normalise(type.typeName), normalise(typeName))
}
@Test
fun `test nested`() {
verify(" java.util.Map < java.util.Map< java.lang.String, java.lang.Integer >, java.util.Map < java.lang.Long , java.lang.String > >")
}
@Test
fun `test simple`() {
verify("java.util.List<java.lang.String>")
}
@Test
fun `test multiple args`() {
verify("java.util.Map<java.lang.String,java.lang.Integer>")
}
@Test
fun `test trailing whitespace`() {
verify("java.util.Map<java.lang.String, java.lang.Integer> ")
}
@Test(expected = NotSerializableException::class)
fun `test trailing text`() {
verify("java.util.Map<java.lang.String, java.lang.Integer>foo")
}
@Test(expected = NotSerializableException::class)
fun `test trailing comma`() {
verify("java.util.Map<java.lang.String, java.lang.Integer,>")
}
@Test(expected = NotSerializableException::class)
fun `test leading comma`() {
verify("java.util.Map<,java.lang.String, java.lang.Integer>")
}
@Test(expected = NotSerializableException::class)
fun `test middle comma`() {
verify("java.util.Map<,java.lang.String,, java.lang.Integer>")
}
@Test(expected = NotSerializableException::class)
fun `test trailing close`() {
verify("java.util.Map<java.lang.String, java.lang.Integer>>")
}
@Test(expected = NotSerializableException::class)
fun `test empty params`() {
verify("java.util.Map<>")
}
@Test(expected = NotSerializableException::class)
fun `test mid whitespace`() {
verify("java.u til.List<java.lang.String>")
}
@Test(expected = NotSerializableException::class)
fun `test mid whitespace2`() {
verify("java.util.List<java.l ng.String>")
}
@Test(expected = NotSerializableException::class)
fun `test wrong number of parameters`() {
verify("java.util.List<java.lang.String, java.lang.Integer>")
}
@Test
fun `test no parameters`() {
verify("java.lang.String")
}
@Test(expected = NotSerializableException::class)
fun `test parameters on non-generic type`() {
verify("java.lang.String<java.lang.Integer>")
}
@Test(expected = NotSerializableException::class)
fun `test excessive nesting`() {
var nested = "java.lang.Integer"
for (i in 1..DeserializedParameterizedType.MAX_DEPTH) {
nested = "java.util.List<$nested>"
}
verify(nested)
}
}

View File

@ -0,0 +1,226 @@
package net.corda.nodeapi.internal.serialization.amqp;
import net.corda.core.serialization.SerializedBytes;
import org.apache.qpid.proton.codec.DecoderImpl;
import org.apache.qpid.proton.codec.EncoderImpl;
import org.junit.Test;
import javax.annotation.Nonnull;
import java.io.NotSerializableException;
import java.nio.ByteBuffer;
import java.util.Objects;
import static org.junit.Assert.assertTrue;
public class JavaSerializationOutputTests {
static class Foo {
private final String bob;
private final int count;
public Foo(String msg, long count) {
this.bob = msg;
this.count = (int) count;
}
@ConstructorForDeserialization
public Foo(String fred, int count) {
this.bob = fred;
this.count = count;
}
public String getFred() {
return bob;
}
public int getCount() {
return count;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Foo foo = (Foo) o;
if (count != foo.count) return false;
return bob != null ? bob.equals(foo.bob) : foo.bob == null;
}
@Override
public int hashCode() {
int result = bob != null ? bob.hashCode() : 0;
result = 31 * result + count;
return result;
}
}
static class UnAnnotatedFoo {
private final String bob;
private final int count;
public UnAnnotatedFoo(String fred, int count) {
this.bob = fred;
this.count = count;
}
public String getFred() {
return bob;
}
public int getCount() {
return count;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
UnAnnotatedFoo foo = (UnAnnotatedFoo) o;
if (count != foo.count) return false;
return bob != null ? bob.equals(foo.bob) : foo.bob == null;
}
@Override
public int hashCode() {
int result = bob != null ? bob.hashCode() : 0;
result = 31 * result + count;
return result;
}
}
static class BoxedFoo {
private final String fred;
private final Integer count;
public BoxedFoo(String fred, Integer count) {
this.fred = fred;
this.count = count;
}
public String getFred() {
return fred;
}
public Integer getCount() {
return count;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BoxedFoo boxedFoo = (BoxedFoo) o;
if (fred != null ? !fred.equals(boxedFoo.fred) : boxedFoo.fred != null) return false;
return count != null ? count.equals(boxedFoo.count) : boxedFoo.count == null;
}
@Override
public int hashCode() {
int result = fred != null ? fred.hashCode() : 0;
result = 31 * result + (count != null ? count.hashCode() : 0);
return result;
}
}
static class BoxedFooNotNull {
private final String fred;
private final Integer count;
public BoxedFooNotNull(String fred, Integer count) {
this.fred = fred;
this.count = count;
}
public String getFred() {
return fred;
}
@Nonnull
public Integer getCount() {
return count;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BoxedFooNotNull boxedFoo = (BoxedFooNotNull) o;
if (fred != null ? !fred.equals(boxedFoo.fred) : boxedFoo.fred != null) return false;
return count != null ? count.equals(boxedFoo.count) : boxedFoo.count == null;
}
@Override
public int hashCode() {
int result = fred != null ? fred.hashCode() : 0;
result = 31 * result + (count != null ? count.hashCode() : 0);
return result;
}
}
private Object serdes(Object obj) throws NotSerializableException {
SerializerFactory factory = new SerializerFactory();
SerializationOutput ser = new SerializationOutput(factory);
SerializedBytes<Object> bytes = ser.serialize(obj);
DecoderImpl decoder = new DecoderImpl();
decoder.register(Envelope.Companion.getDESCRIPTOR(), Envelope.Companion);
decoder.register(Schema.Companion.getDESCRIPTOR(), Schema.Companion);
decoder.register(Descriptor.Companion.getDESCRIPTOR(), Descriptor.Companion);
decoder.register(Field.Companion.getDESCRIPTOR(), Field.Companion);
decoder.register(CompositeType.Companion.getDESCRIPTOR(), CompositeType.Companion);
decoder.register(Choice.Companion.getDESCRIPTOR(), Choice.Companion);
decoder.register(RestrictedType.Companion.getDESCRIPTOR(), RestrictedType.Companion);
new EncoderImpl(decoder);
decoder.setByteBuffer(ByteBuffer.wrap(bytes.getBytes(), 8, bytes.getSize() - 8));
Envelope result = (Envelope) decoder.readObject();
assertTrue(result != null);
DeserializationInput des = new DeserializationInput();
Object desObj = des.deserialize(bytes, Object.class);
assertTrue(Objects.deepEquals(obj, desObj));
// Now repeat with a re-used factory
SerializationOutput ser2 = new SerializationOutput(factory);
DeserializationInput des2 = new DeserializationInput(factory);
Object desObj2 = des2.deserialize(ser2.serialize(obj), Object.class);
assertTrue(Objects.deepEquals(obj, desObj2));
// TODO: check schema is as expected
return desObj2;
}
@Test
public void testJavaConstructorAnnotations() throws NotSerializableException {
Foo obj = new Foo("Hello World!", 123);
serdes(obj);
}
@Test
public void testJavaConstructorWithoutAnnotations() throws NotSerializableException {
UnAnnotatedFoo obj = new UnAnnotatedFoo("Hello World!", 123);
serdes(obj);
}
@Test
public void testBoxedTypes() throws NotSerializableException {
BoxedFoo obj = new BoxedFoo("Hello World!", 123);
serdes(obj);
}
@Test
public void testBoxedTypesNotNull() throws NotSerializableException {
BoxedFooNotNull obj = new BoxedFooNotNull("Hello World!", 123);
serdes(obj);
}
}

View File

@ -0,0 +1,590 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException
import net.corda.core.identity.AbstractParty
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.LedgerTransaction
import net.corda.nodeapi.RPCException
import net.corda.nodeapi.internal.serialization.AbstractAMQPSerializationScheme
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive
import net.corda.nodeapi.internal.serialization.amqp.custom.*
import net.corda.testing.MEGA_CORP
import net.corda.testing.MEGA_CORP_PUBKEY
import org.apache.qpid.proton.amqp.*
import org.apache.qpid.proton.codec.DecoderImpl
import org.apache.qpid.proton.codec.EncoderImpl
import org.junit.Test
import java.io.IOException
import java.io.NotSerializableException
import java.math.BigDecimal
import java.nio.ByteBuffer
import java.time.Instant
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
class SerializationOutputTests {
data class Foo(val bar: String, val pub: Int)
data class testFloat(val f: Float)
data class testDouble(val d: Double)
data class testShort(val s: Short)
data class testBoolean(val b : Boolean)
interface FooInterface {
val pub: Int
}
data class FooImplements(val bar: String, override val pub: Int) : FooInterface
data class FooImplementsAndList(val bar: String, override val pub: Int, val names: List<String>) : FooInterface
data class WrapHashMap(val map: Map<String, String>)
data class WrapFooListArray(val listArray: Array<List<Foo>>) {
override fun equals(other: Any?): Boolean {
return other is WrapFooListArray && Objects.deepEquals(listArray, other.listArray)
}
override fun hashCode(): Int {
return 1 // This isn't used, but without overriding we get a warning.
}
}
data class Woo(val fred: Int) {
@Suppress("unused")
val bob = "Bob"
}
data class Woo2(val fred: Int, val bob: String = "Bob") {
@ConstructorForDeserialization constructor(fred: Int) : this(fred, "Ginger")
}
@CordaSerializable
data class AnnotatedWoo(val fred: Int) {
@Suppress("unused")
val bob = "Bob"
}
class FooList : ArrayList<Foo>()
@Suppress("AddVarianceModifier")
data class GenericFoo<T>(val bar: String, val pub: T)
data class ContainsGenericFoo(val contain: GenericFoo<String>)
data class NestedGenericFoo<T>(val contain: GenericFoo<T>)
data class ContainsNestedGenericFoo(val contain: NestedGenericFoo<String>)
data class TreeMapWrapper(val tree: TreeMap<Int, Foo>)
data class NavigableMapWrapper(val tree: NavigableMap<Int, Foo>)
data class SortedSetWrapper(val set: SortedSet<Int>)
open class InheritedGeneric<X>(val foo: X)
data class ExtendsGeneric(val bar: Int, val pub: String) : InheritedGeneric<String>(pub)
interface GenericInterface<X> {
val pub: X
}
data class ImplementsGenericString(val bar: Int, override val pub: String) : GenericInterface<String>
data class ImplementsGenericX<Y>(val bar: Int, override val pub: Y) : GenericInterface<Y>
abstract class AbstractGenericX<Z> : GenericInterface<Z>
data class InheritGenericX<A>(val duke: Double, override val pub: A) : AbstractGenericX<A>()
data class CapturesGenericX(val foo: GenericInterface<String>)
object KotlinObject
class Mismatch(fred: Int) {
val ginger: Int = fred
override fun equals(other: Any?): Boolean = (other as? Mismatch)?.ginger == ginger
override fun hashCode(): Int = ginger
}
class MismatchType(fred: Long) {
val ginger: Int = fred.toInt()
override fun equals(other: Any?): Boolean = (other as? MismatchType)?.ginger == ginger
override fun hashCode(): Int = ginger
}
@CordaSerializable
interface AnnotatedInterface
data class InheritAnnotation(val foo: String) : AnnotatedInterface
data class PolymorphicProperty(val foo: FooInterface?)
private fun serdes(obj: Any,
factory: SerializerFactory = SerializerFactory(),
freshDeserializationFactory: SerializerFactory = SerializerFactory(),
expectedEqual: Boolean = true,
expectDeserializedEqual: Boolean = true): Any {
val ser = SerializationOutput(factory)
val bytes = ser.serialize(obj)
val decoder = DecoderImpl().apply {
this.register(Envelope.DESCRIPTOR, Envelope.Companion)
this.register(Schema.DESCRIPTOR, Schema.Companion)
this.register(Descriptor.DESCRIPTOR, Descriptor.Companion)
this.register(Field.DESCRIPTOR, Field.Companion)
this.register(CompositeType.DESCRIPTOR, CompositeType.Companion)
this.register(Choice.DESCRIPTOR, Choice.Companion)
this.register(RestrictedType.DESCRIPTOR, RestrictedType.Companion)
}
EncoderImpl(decoder)
decoder.setByteBuffer(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8))
// Check that a vanilla AMQP decoder can deserialize without schema.
val result = decoder.readObject() as Envelope
assertNotNull(result)
val des = DeserializationInput(freshDeserializationFactory)
val desObj = des.deserialize(bytes)
assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual)
// Now repeat with a re-used factory
val ser2 = SerializationOutput(factory)
val des2 = DeserializationInput(factory)
val desObj2 = des2.deserialize(ser2.serialize(obj))
assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual)
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
// TODO: add some schema assertions to check correctly formed.
return desObj2
}
@Test
fun isPrimitive() {
assertTrue(isPrimitive(Character::class.java))
assertTrue(isPrimitive(Boolean::class.java))
assertTrue(isPrimitive(Byte::class.java))
assertTrue(isPrimitive(UnsignedByte::class.java))
assertTrue(isPrimitive(Short::class.java))
assertTrue(isPrimitive(UnsignedShort::class.java))
assertTrue(isPrimitive(Int::class.java))
assertTrue(isPrimitive(UnsignedInteger::class.java))
assertTrue(isPrimitive(Long::class.java))
assertTrue(isPrimitive(UnsignedLong::class.java))
assertTrue(isPrimitive(Float::class.java))
assertTrue(isPrimitive(Double::class.java))
assertTrue(isPrimitive(Decimal32::class.java))
assertTrue(isPrimitive(Decimal64::class.java))
assertTrue(isPrimitive(Decimal128::class.java))
assertTrue(isPrimitive(Char::class.java))
assertTrue(isPrimitive(Date::class.java))
assertTrue(isPrimitive(UUID::class.java))
assertTrue(isPrimitive(ByteArray::class.java))
assertTrue(isPrimitive(String::class.java))
assertTrue(isPrimitive(Symbol::class.java))
}
@Test
fun `test foo`() {
val obj = Foo("Hello World!", 123)
serdes(obj)
}
@Test
fun `test float`() {
val obj = testFloat(10.0F)
serdes(obj)
}
@Test
fun `test double`() {
val obj = testDouble(10.0)
serdes(obj)
}
@Test
fun `test short`() {
val obj = testShort(1)
serdes(obj)
}
@Test
fun `test bool`() {
val obj = testBoolean(true)
serdes(obj)
}
@Test
fun `test foo implements`() {
val obj = FooImplements("Hello World!", 123)
serdes(obj)
}
@Test
fun `test foo implements and list`() {
val obj = FooImplementsAndList("Hello World!", 123, listOf("Fred", "Ginger"))
serdes(obj)
}
@Test(expected = IllegalArgumentException::class)
fun `test dislike of HashMap`() {
val obj = WrapHashMap(HashMap<String, String>())
serdes(obj)
}
@Test
fun `test string array`() {
val obj = arrayOf("Fred", "Ginger")
serdes(obj)
}
@Test
fun `test foo array`() {
val obj = arrayOf(Foo("Fred", 1), Foo("Ginger", 2))
serdes(obj)
}
@Test(expected = NotSerializableException::class)
fun `test top level list array`() {
val obj = arrayOf(listOf("Fred", "Ginger"), listOf("Rogers", "Hammerstein"))
serdes(obj)
}
@Test
fun `test foo list array`() {
val obj = WrapFooListArray(arrayOf(listOf(Foo("Fred", 1), Foo("Ginger", 2)), listOf(Foo("Rogers", 3), Foo("Hammerstein", 4))))
serdes(obj)
}
@Test
fun `test not all properties in constructor`() {
val obj = Woo(2)
serdes(obj)
}
@Test
fun `test annotated constructor`() {
val obj = Woo2(3)
serdes(obj)
}
@Test(expected = NotSerializableException::class)
fun `test whitelist`() {
val obj = Woo2(4)
serdes(obj, SerializerFactory(EmptyWhitelist))
}
@Test
fun `test annotation whitelisting`() {
val obj = AnnotatedWoo(5)
serdes(obj, SerializerFactory(EmptyWhitelist))
}
@Test(expected = NotSerializableException::class)
fun `test generic list subclass is not supported`() {
val obj = FooList()
serdes(obj)
}
@Test
fun `test generic foo`() {
val obj = GenericFoo("Fred", "Ginger")
serdes(obj)
}
@Test
fun `test generic foo as property`() {
val obj = ContainsGenericFoo(GenericFoo("Fred", "Ginger"))
serdes(obj)
}
@Test
fun `test nested generic foo as property`() {
val obj = ContainsNestedGenericFoo(NestedGenericFoo(GenericFoo("Fred", "Ginger")))
serdes(obj)
}
// TODO: Generic interfaces / superclasses
@Test
fun `test extends generic`() {
val obj = ExtendsGeneric(1, "Ginger")
serdes(obj)
}
@Test
fun `test implements generic`() {
val obj = ImplementsGenericString(1, "Ginger")
serdes(obj)
}
@Test
fun `test implements generic captured`() {
val obj = CapturesGenericX(ImplementsGenericX(1, "Ginger"))
serdes(obj)
}
@Test
fun `test inherits generic captured`() {
val obj = CapturesGenericX(InheritGenericX(1.0, "Ginger"))
serdes(obj)
}
@Test(expected = NotSerializableException::class)
fun `test TreeMap`() {
val obj = TreeMap<Int, Foo>()
obj[456] = Foo("Fred", 123)
serdes(obj)
}
@Test(expected = NotSerializableException::class)
fun `test TreeMap property`() {
val obj = TreeMapWrapper(TreeMap<Int, Foo>())
obj.tree[456] = Foo("Fred", 123)
serdes(obj)
}
@Test
fun `test NavigableMap property`() {
val obj = NavigableMapWrapper(TreeMap<Int, Foo>())
obj.tree[456] = Foo("Fred", 123)
serdes(obj)
}
@Test
fun `test SortedSet property`() {
val obj = SortedSetWrapper(TreeSet<Int>())
obj.set += 456
serdes(obj)
}
@Test(expected = NotSerializableException::class)
fun `test mismatched property and constructor naming`() {
val obj = Mismatch(456)
serdes(obj)
}
@Test(expected = NotSerializableException::class)
fun `test mismatched property and constructor type`() {
val obj = MismatchType(456)
serdes(obj)
}
@Test
fun `test custom serializers on public key`() {
val factory = SerializerFactory()
factory.register(PublicKeySerializer)
val factory2 = SerializerFactory()
factory2.register(PublicKeySerializer)
val obj = MEGA_CORP_PUBKEY
serdes(obj, factory, factory2)
}
@Test
fun `test annotation is inherited`() {
val obj = InheritAnnotation("blah")
serdes(obj, SerializerFactory(EmptyWhitelist))
}
@Test
fun `test throwables serialize`() {
val factory = SerializerFactory()
factory.register(ThrowableSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(ThrowableSerializer(factory2))
val t = IllegalAccessException("message").fillInStackTrace()
val desThrowable = serdes(t, factory, factory2, false) as Throwable
assertSerializedThrowableEquivalent(t, desThrowable)
}
@Test
fun `test complex throwables serialize`() {
val factory = SerializerFactory()
factory.register(ThrowableSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(ThrowableSerializer(factory2))
try {
try {
throw IOException("Layer 1")
} catch(t: Throwable) {
throw IllegalStateException("Layer 2", t)
}
} catch(t: Throwable) {
val desThrowable = serdes(t, factory, factory2, false) as Throwable
assertSerializedThrowableEquivalent(t, desThrowable)
}
}
fun assertSerializedThrowableEquivalent(t: Throwable, desThrowable: Throwable) {
assertTrue(desThrowable is CordaRuntimeException) // Since we don't handle the other case(s) yet
if (desThrowable is CordaRuntimeException) {
assertEquals("${t.javaClass.name}: ${t.message}", desThrowable.message)
assertTrue(desThrowable is CordaRuntimeException)
assertTrue(Objects.deepEquals(t.stackTrace, desThrowable.stackTrace))
assertEquals(t.suppressed.size, desThrowable.suppressed.size)
t.suppressed.zip(desThrowable.suppressed).forEach { (before, after) -> assertSerializedThrowableEquivalent(before, after) }
}
}
@Test
fun `test suppressed throwables serialize`() {
val factory = SerializerFactory()
factory.register(ThrowableSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(ThrowableSerializer(factory2))
try {
try {
throw IOException("Layer 1")
} catch(t: Throwable) {
val e = IllegalStateException("Layer 2")
e.addSuppressed(t)
throw e
}
} catch(t: Throwable) {
val desThrowable = serdes(t, factory, factory2, false) as Throwable
assertSerializedThrowableEquivalent(t, desThrowable)
}
}
@Test
fun `test flow corda exception subclasses serialize`() {
val factory = SerializerFactory()
factory.register(ThrowableSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(ThrowableSerializer(factory2))
val obj = FlowException("message").fillInStackTrace()
serdes(obj, factory, factory2)
}
@Test
fun `test RPC corda exception subclasses serialize`() {
val factory = SerializerFactory()
factory.register(ThrowableSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(ThrowableSerializer(factory2))
val obj = RPCException("message").fillInStackTrace()
serdes(obj, factory, factory2)
}
@Test
fun `test polymorphic property`() {
val obj = PolymorphicProperty(FooImplements("Ginger", 12))
serdes(obj)
}
@Test
fun `test null polymorphic property`() {
val obj = PolymorphicProperty(null)
serdes(obj)
}
@Test
fun `test kotlin object`() {
serdes(KotlinObject)
}
object FooContract : Contract {
override fun verify(tx: LedgerTransaction) {
}
override val legalContractReference: SecureHash = SecureHash.Companion.sha256("FooContractLegal")
}
class FooState : ContractState {
override val contract: Contract
get() = FooContract
override val participants: List<AbstractParty>
get() = emptyList()
}
@Test
fun `test transaction state`() {
val state = TransactionState<FooState>(FooState(), MEGA_CORP)
val factory = SerializerFactory()
AbstractAMQPSerializationScheme.registerCustomSerializers(factory)
val factory2 = SerializerFactory()
AbstractAMQPSerializationScheme.registerCustomSerializers(factory2)
val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false)
assertTrue(desState is TransactionState<*>)
assertTrue((desState as TransactionState<*>).data is FooState)
assertTrue(desState.notary == state.notary)
assertTrue(desState.encumbrance == state.encumbrance)
}
@Test
fun `test currencies serialize`() {
val factory = SerializerFactory()
factory.register(CurrencySerializer)
val factory2 = SerializerFactory()
factory2.register(CurrencySerializer)
val obj = Currency.getInstance("USD")
serdes(obj, factory, factory2)
}
@Test
fun `test big decimals serialize`() {
val factory = SerializerFactory()
factory.register(BigDecimalSerializer)
val factory2 = SerializerFactory()
factory2.register(BigDecimalSerializer)
val obj = BigDecimal("100000000000000000000000000000.00")
serdes(obj, factory, factory2)
}
@Test
fun `test instants serialize`() {
val factory = SerializerFactory()
factory.register(InstantSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(InstantSerializer(factory2))
val obj = Instant.now()
serdes(obj, factory, factory2)
}
@Test
fun `test StateRef serialize`() {
val factory = SerializerFactory()
factory.register(InstantSerializer(factory))
val factory2 = SerializerFactory()
factory2.register(InstantSerializer(factory2))
val obj = StateRef(SecureHash.randomSHA256(), 0)
serdes(obj, factory, factory2)
}
}

View File

@ -0,0 +1,498 @@
package net.corda.nodeapi.internal.serialization.carpenter
import org.junit.Test
import java.beans.Introspector
import java.lang.reflect.Field
import java.lang.reflect.Method
import javax.annotation.Nonnull
import javax.annotation.Nullable
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
class ClassCarpenterTest {
interface DummyInterface {
val a: String
val b: Int
}
val cc = ClassCarpenter()
// We have to ignore synthetic fields even though ClassCarpenter doesn't create any because the JaCoCo
// coverage framework auto-magically injects one method and one field into every class loaded into the JVM.
val Class<*>.nonSyntheticFields: List<Field> get() = declaredFields.filterNot { it.isSynthetic }
val Class<*>.nonSyntheticMethods: List<Method> get() = declaredMethods.filterNot { it.isSynthetic }
@Test
fun empty() {
val clazz = cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null))
assertEquals(0, clazz.nonSyntheticFields.size)
assertEquals(2, clazz.nonSyntheticMethods.size) // get, toString
assertEquals(0, clazz.declaredConstructors[0].parameterCount)
clazz.newInstance() // just test there's no exception.
}
@Test
fun prims() {
val clazz = cc.build(ClassSchema(
"gen.Prims",
mapOf(
"anIntField" to Int::class.javaPrimitiveType!!,
"aLongField" to Long::class.javaPrimitiveType!!,
"someCharField" to Char::class.javaPrimitiveType!!,
"aShortField" to Short::class.javaPrimitiveType!!,
"doubleTrouble" to Double::class.javaPrimitiveType!!,
"floatMyBoat" to Float::class.javaPrimitiveType!!,
"byteMe" to Byte::class.javaPrimitiveType!!,
"booleanField" to Boolean::class.javaPrimitiveType!!).mapValues {
NonNullableField(it.value)
}))
assertEquals(8, clazz.nonSyntheticFields.size)
assertEquals(10, clazz.nonSyntheticMethods.size)
assertEquals(8, clazz.declaredConstructors[0].parameterCount)
val i = clazz.constructors[0].newInstance(1, 2L, 'c', 4.toShort(), 1.23, 4.56F, 127.toByte(), true)
assertEquals(1, clazz.getMethod("getAnIntField").invoke(i))
assertEquals(2L, clazz.getMethod("getALongField").invoke(i))
assertEquals('c', clazz.getMethod("getSomeCharField").invoke(i))
assertEquals(4.toShort(), clazz.getMethod("getAShortField").invoke(i))
assertEquals(1.23, clazz.getMethod("getDoubleTrouble").invoke(i))
assertEquals(4.56F, clazz.getMethod("getFloatMyBoat").invoke(i))
assertEquals(127.toByte(), clazz.getMethod("getByteMe").invoke(i))
assertEquals(true, clazz.getMethod("getBooleanField").invoke(i))
val sfa = i as SimpleFieldAccess
assertEquals(1, sfa["anIntField"])
assertEquals(2L, sfa["aLongField"])
assertEquals('c', sfa["someCharField"])
assertEquals(4.toShort(), sfa["aShortField"])
assertEquals(1.23, sfa["doubleTrouble"])
assertEquals(4.56F, sfa["floatMyBoat"])
assertEquals(127.toByte(), sfa["byteMe"])
assertEquals(true, sfa["booleanField"])
}
private fun genPerson(): Pair<Class<*>, Any> {
val clazz = cc.build(ClassSchema("gen.Person", mapOf(
"age" to Int::class.javaPrimitiveType!!,
"name" to String::class.java
).mapValues { NonNullableField(it.value) }))
val i = clazz.constructors[0].newInstance(32, "Mike")
return Pair(clazz, i)
}
@Test
fun objs() {
val (clazz, i) = genPerson()
assertEquals("Mike", clazz.getMethod("getName").invoke(i))
assertEquals("Mike", (i as SimpleFieldAccess)["name"])
}
@Test
fun `generated toString`() {
val (_, i) = genPerson()
assertEquals("Person{age=32, name=Mike}", i.toString())
}
@Test(expected = DuplicateNameException::class)
fun duplicates() {
cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null))
cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null))
}
@Test
fun `can refer to each other`() {
val (clazz1, i) = genPerson()
val clazz2 = cc.build(ClassSchema("gen.Referee", mapOf(
"ref" to NonNullableField(clazz1)
)))
val i2 = clazz2.constructors[0].newInstance(i)
assertEquals(i, (i2 as SimpleFieldAccess)["ref"])
}
@Test
fun superclasses() {
val schema1 = ClassSchema(
"gen.A",
mapOf("a" to NonNullableField(String::class.java)))
val schema2 = ClassSchema(
"gen.B",
mapOf("b" to NonNullableField(String::class.java)),
schema1)
val clazz = cc.build(schema2)
val i = clazz.constructors[0].newInstance("xa", "xb") as SimpleFieldAccess
assertEquals("xa", i["a"])
assertEquals("xb", i["b"])
assertEquals("B{a=xa, b=xb}", i.toString())
}
@Test
fun interfaces() {
val schema1 = ClassSchema(
"gen.A",
mapOf("a" to NonNullableField(String::class.java)))
val schema2 = ClassSchema("gen.B",
mapOf("b" to NonNullableField(Int::class.java)),
schema1,
interfaces = listOf(DummyInterface::class.java))
val clazz = cc.build(schema2)
val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface
assertEquals("xa", i.a)
assertEquals(1, i.b)
}
@Test(expected = InterfaceMismatchException::class)
fun `mismatched interface`() {
val schema1 = ClassSchema(
"gen.A",
mapOf("a" to NonNullableField(String::class.java)))
val schema2 = ClassSchema(
"gen.B",
mapOf("c" to NonNullableField(Int::class.java)),
schema1,
interfaces = listOf(DummyInterface::class.java))
val clazz = cc.build(schema2)
val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface
assertEquals(1, i.b)
}
@Test
fun `generate interface`() {
val schema1 = InterfaceSchema(
"gen.Interface",
mapOf("a" to NonNullableField(Int::class.java)))
val iface = cc.build(schema1)
assert(iface.isInterface)
assert(iface.constructors.isEmpty())
assertEquals(iface.declaredMethods.size, 1)
assertEquals(iface.declaredMethods[0].name, "getA")
val schema2 = ClassSchema(
"gen.Derived",
mapOf("a" to NonNullableField(Int::class.java)),
interfaces = listOf(iface))
val clazz = cc.build(schema2)
val testA = 42
val i = clazz.constructors[0].newInstance(testA) as SimpleFieldAccess
assertEquals(testA, i["a"])
}
@Test
fun `generate multiple interfaces`() {
val iFace1 = InterfaceSchema(
"gen.Interface1",
mapOf(
"a" to NonNullableField(Int::class.java),
"b" to NonNullableField(String::class.java)))
val iFace2 = InterfaceSchema(
"gen.Interface2",
mapOf(
"c" to NonNullableField(Int::class.java),
"d" to NonNullableField(String::class.java)))
val class1 = ClassSchema(
"gen.Derived",
mapOf(
"a" to NonNullableField(Int::class.java),
"b" to NonNullableField(String::class.java),
"c" to NonNullableField(Int::class.java),
"d" to NonNullableField(String::class.java)),
interfaces = listOf(cc.build(iFace1), cc.build(iFace2)))
val clazz = cc.build(class1)
val testA = 42
val testB = "don't touch me, I'm scared"
val testC = 0xDEAD
val testD = "wibble"
val i = clazz.constructors[0].newInstance(testA, testB, testC, testD) as SimpleFieldAccess
assertEquals(testA, i["a"])
assertEquals(testB, i["b"])
assertEquals(testC, i["c"])
assertEquals(testD, i["d"])
}
@Test
fun `interface implementing interface`() {
val iFace1 = InterfaceSchema(
"gen.Interface1",
mapOf(
"a" to NonNullableField(Int::class.java),
"b" to NonNullableField(String::class.java)))
val iFace2 = InterfaceSchema(
"gen.Interface2",
mapOf(
"c" to NonNullableField(Int::class.java),
"d" to NonNullableField(String::class.java)),
interfaces = listOf(cc.build(iFace1)))
val class1 = ClassSchema(
"gen.Derived",
mapOf(
"a" to NonNullableField(Int::class.java),
"b" to NonNullableField(String::class.java),
"c" to NonNullableField(Int::class.java),
"d" to NonNullableField(String::class.java)),
interfaces = listOf(cc.build(iFace2)))
val clazz = cc.build(class1)
val testA = 99
val testB = "green is not a creative colour"
val testC = 7
val testD = "I like jam"
val i = clazz.constructors[0].newInstance(testA, testB, testC, testD) as SimpleFieldAccess
assertEquals(testA, i["a"])
assertEquals(testB, i["b"])
assertEquals(testC, i["c"])
assertEquals(testD, i["d"])
}
@Test(expected = java.lang.IllegalArgumentException::class)
fun `null parameter small int`() {
val className = "iEnjoySwede"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NonNullableField(Int::class.java)))
val clazz = cc.build(schema)
val a : Int? = null
clazz.constructors[0].newInstance(a)
}
@Test(expected = NullablePrimitiveException::class)
fun `nullable parameter small int`() {
val className = "iEnjoySwede"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NullableField(Int::class.java)))
cc.build(schema)
}
@Test
fun `nullable parameter integer`() {
val className = "iEnjoyWibble"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NullableField(Integer::class.java)))
val clazz = cc.build(schema)
val a1 : Int? = null
clazz.constructors[0].newInstance(a1)
val a2 : Int? = 10
clazz.constructors[0].newInstance(a2)
}
@Test
fun `non nullable parameter integer with non null`() {
val className = "iEnjoyWibble"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NonNullableField(Integer::class.java)))
val clazz = cc.build(schema)
val a : Int? = 10
clazz.constructors[0].newInstance(a)
}
@Test(expected = java.lang.reflect.InvocationTargetException::class)
fun `non nullable parameter integer with null`() {
val className = "iEnjoyWibble"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NonNullableField(Integer::class.java)))
val clazz = cc.build(schema)
val a : Int? = null
clazz.constructors[0].newInstance(a)
}
@Test
@Suppress("UNCHECKED_CAST")
fun `int array`() {
val className = "iEnjoyPotato"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NonNullableField(IntArray::class.java)))
val clazz = cc.build(schema)
val i = clazz.constructors[0].newInstance(intArrayOf(1, 2, 3)) as SimpleFieldAccess
val arr = clazz.getMethod("getA").invoke(i)
assertEquals(1, (arr as IntArray)[0])
assertEquals(2, arr[1])
assertEquals(3, arr[2])
assertEquals("$className{a=[1, 2, 3]}", i.toString())
}
@Test(expected = java.lang.reflect.InvocationTargetException::class)
fun `nullable int array throws`() {
val className = "iEnjoySwede"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NonNullableField(IntArray::class.java)))
val clazz = cc.build(schema)
val a : IntArray? = null
clazz.constructors[0].newInstance(a)
}
@Test
@Suppress("UNCHECKED_CAST")
fun `integer array`() {
val className = "iEnjoyFlan"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NonNullableField(Array<Int>::class.java)))
val clazz = cc.build(schema)
val i = clazz.constructors[0].newInstance(arrayOf(1, 2, 3)) as SimpleFieldAccess
val arr = clazz.getMethod("getA").invoke(i)
assertEquals(1, (arr as Array<Int>)[0])
assertEquals(2, arr[1])
assertEquals(3, arr[2])
assertEquals("$className{a=[1, 2, 3]}", i.toString())
}
@Test
@Suppress("UNCHECKED_CAST")
fun `int array with ints`() {
val className = "iEnjoyCrumble"
val schema = ClassSchema(
"gen.$className", mapOf(
"a" to Int::class.java,
"b" to IntArray::class.java,
"c" to Int::class.java).mapValues { NonNullableField(it.value) })
val clazz = cc.build(schema)
val i = clazz.constructors[0].newInstance(2, intArrayOf(4, 8), 16) as SimpleFieldAccess
assertEquals(2, clazz.getMethod("getA").invoke(i))
assertEquals(4, (clazz.getMethod("getB").invoke(i) as IntArray)[0])
assertEquals(8, (clazz.getMethod("getB").invoke(i) as IntArray)[1])
assertEquals(16, clazz.getMethod("getC").invoke(i))
assertEquals("$className{a=2, b=[4, 8], c=16}", i.toString())
}
@Test
@Suppress("UNCHECKED_CAST")
fun `multiple int arrays`() {
val className = "iEnjoyJam"
val schema = ClassSchema(
"gen.$className", mapOf(
"a" to IntArray::class.java,
"b" to Int::class.java,
"c" to IntArray::class.java).mapValues { NonNullableField(it.value) })
val clazz = cc.build(schema)
val i = clazz.constructors[0].newInstance(intArrayOf(1, 2), 3, intArrayOf(4, 5, 6))
assertEquals(1, (clazz.getMethod("getA").invoke(i) as IntArray)[0])
assertEquals(2, (clazz.getMethod("getA").invoke(i) as IntArray)[1])
assertEquals(3, clazz.getMethod("getB").invoke(i))
assertEquals(4, (clazz.getMethod("getC").invoke(i) as IntArray)[0])
assertEquals(5, (clazz.getMethod("getC").invoke(i) as IntArray)[1])
assertEquals(6, (clazz.getMethod("getC").invoke(i) as IntArray)[2])
assertEquals("$className{a=[1, 2], b=3, c=[4, 5, 6]}", i.toString())
}
@Test
@Suppress("UNCHECKED_CAST")
fun `string array`() {
val className = "iEnjoyToast"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NullableField(Array<String>::class.java)))
val clazz = cc.build(schema)
val i = clazz.constructors[0].newInstance(arrayOf("toast", "butter", "jam"))
val arr = clazz.getMethod("getA").invoke(i) as Array<String>
assertEquals("toast", arr[0])
assertEquals("butter", arr[1])
assertEquals("jam", arr[2])
}
@Test
@Suppress("UNCHECKED_CAST")
fun `string arrays`() {
val className = "iEnjoyToast"
val schema = ClassSchema(
"gen.$className",
mapOf(
"a" to Array<String>::class.java,
"b" to String::class.java,
"c" to Array<String>::class.java).mapValues { NullableField(it.value) })
val clazz = cc.build(schema)
val i = clazz.constructors[0].newInstance(
arrayOf("bread", "spread", "cheese"),
"and on the side",
arrayOf("some pickles", "some fries"))
val arr1 = clazz.getMethod("getA").invoke(i) as Array<String>
val arr2 = clazz.getMethod("getC").invoke(i) as Array<String>
assertEquals("bread", arr1[0])
assertEquals("spread", arr1[1])
assertEquals("cheese", arr1[2])
assertEquals("and on the side", clazz.getMethod("getB").invoke(i))
assertEquals("some pickles", arr2[0])
assertEquals("some fries", arr2[1])
}
@Test
fun `nullable sets annotations`() {
val className = "iEnjoyJam"
val schema = ClassSchema(
"gen.$className",
mapOf("a" to NullableField(String::class.java),
"b" to NonNullableField(String::class.java)))
val clazz = cc.build(schema)
assertEquals (2, clazz.declaredFields.size)
assertEquals (1, clazz.getDeclaredField("a").annotations.size)
assertEquals(Nullable::class.java, clazz.getDeclaredField("a").annotations[0].annotationClass.java)
assertEquals (1, clazz.getDeclaredField("b").annotations.size)
assertEquals(Nonnull::class.java, clazz.getDeclaredField("b").annotations[0].annotationClass.java)
assertEquals (1, clazz.getMethod("getA").annotations.size)
assertEquals(Nullable::class.java, clazz.getMethod("getA").annotations[0].annotationClass.java)
assertEquals (1, clazz.getMethod("getB").annotations.size)
assertEquals(Nonnull::class.java, clazz.getMethod("getB").annotations[0].annotationClass.java)
}
@Test
fun beanTest() {
val schema = ClassSchema(
"pantsPantsPants",
mapOf("a" to NonNullableField(Integer::class.java)))
val clazz = cc.build(schema)
val descriptors = Introspector.getBeanInfo(clazz).propertyDescriptors
assertEquals(2, descriptors.size)
assertNotEquals(null, descriptors.find { it.name == "a" })
assertNotEquals(null, descriptors.find { it.name == "class" })
}
}

View File

@ -0,0 +1,46 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.amqp.Field
import net.corda.nodeapi.internal.serialization.amqp.Schema
import net.corda.nodeapi.internal.serialization.amqp.TypeNotation
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput
fun mangleName(name: String) = "${name}__carpenter"
/**
* given a list of class names work through the amqp envelope schema and alter any that
* match in the fashion defined above
*/
fun Schema.mangleNames(names: List<String>): Schema {
val newTypes: MutableList<TypeNotation> = mutableListOf()
for (type in types) {
val newName = if (type.name in names) mangleName(type.name) else type.name
val newProvides = type.provides.map { if (it in names) mangleName(it) else it }
val newFields = mutableListOf<Field>()
(type as CompositeType).fields.forEach {
val fieldType = if (it.type in names) mangleName(it.type) else it.type
val requires =
if (it.requires.isNotEmpty() && (it.requires[0] in names)) listOf(mangleName(it.requires[0]))
else it.requires
newFields.add(it.copy(type = fieldType, requires = requires))
}
newTypes.add(type.copy(name = newName, provides = newProvides, fields = newFields))
}
return Schema(types = newTypes)
}
open class AmqpCarpenterBase {
var factory = SerializerFactory()
fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz)
fun testName(): String = Thread.currentThread().stackTrace[2].methodName
@Suppress("NOTHING_TO_INLINE")
inline fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz"
}

View File

@ -0,0 +1,285 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
@CordaSerializable
interface I_ {
val a: Int
}
class CompositeMembers : AmqpCarpenterBase() {
@Test
fun bothKnown() {
val testA = 10
val testB = 20
@CordaSerializable
data class A(val a: Int)
@CordaSerializable
data class B(val a: A, var b: Int)
val b = B(A(testA), testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assert(obj.obj is B)
val amqpObj = obj.obj as B
assertEquals(testB, amqpObj.b)
assertEquals(testA, amqpObj.a.a)
assertEquals(2, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
assert(obj.envelope.schema.types[1] is CompositeType)
var amqpSchemaA: CompositeType? = null
var amqpSchemaB: CompositeType? = null
for (type in obj.envelope.schema.types) {
when (type.name.split ("$").last()) {
"A" -> amqpSchemaA = type as CompositeType
"B" -> amqpSchemaB = type as CompositeType
}
}
assert(amqpSchemaA != null)
assert(amqpSchemaB != null)
// Just ensure the amqp schema matches what we want before we go messing
// around with the internals
assertEquals(1, amqpSchemaA?.fields?.size)
assertEquals("a", amqpSchemaA!!.fields[0].name)
assertEquals("int", amqpSchemaA.fields[0].type)
assertEquals(2, amqpSchemaB?.fields?.size)
assertEquals("a", amqpSchemaB!!.fields[0].name)
assertEquals(classTestName("A"), amqpSchemaB.fields[0].type)
assertEquals("b", amqpSchemaB.fields[1].name)
assertEquals("int", amqpSchemaB.fields[1].type)
val metaSchema = obj.envelope.schema.carpenterSchema()
// if we know all the classes there is nothing to really achieve here
assert(metaSchema.carpenterSchemas.isEmpty())
assert(metaSchema.dependsOn.isEmpty())
assert(metaSchema.dependencies.isEmpty())
}
// you cannot have an element of a composite class we know about
// that is unknown as that should be impossible. If we have the containing
// class in the class path then we must have all of it's constituent elements
@Test(expected = UncarpentableException::class)
fun nestedIsUnknown() {
val testA = 10
val testB = 20
@CordaSerializable
data class A(override val a: Int) : I_
@CordaSerializable
data class B(val a: A, var b: Int)
val b = B(A(testA), testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
val amqpSchema = obj.envelope.schema.mangleNames(listOf (classTestName ("A")))
assert(obj.obj is B)
amqpSchema.carpenterSchema()
}
@Test
fun ParentIsUnknown() {
val testA = 10
val testB = 20
@CordaSerializable
data class A(override val a: Int) : I_
@CordaSerializable
data class B(val a: A, var b: Int)
val b = B(A(testA), testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assert(obj.obj is B)
val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("B")))
val carpenterSchema = amqpSchema.carpenterSchema()
assertEquals(1, carpenterSchema.size)
val metaCarpenter = MetaCarpenter(carpenterSchema)
metaCarpenter.build()
assert(mangleName(classTestName("B")) in metaCarpenter.objects)
}
@Test
fun BothUnknown() {
val testA = 10
val testB = 20
@CordaSerializable
data class A(override val a: Int) : I_
@CordaSerializable
data class B(val a: A, var b: Int)
val b = B(A(testA), testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assert(obj.obj is B)
val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B")))
val carpenterSchema = amqpSchema.carpenterSchema()
// just verify we're in the expected initial state, A is carpentable, B is not because
// it depends on A and the dependency chains are in place
assertEquals(1, carpenterSchema.size)
assertEquals(mangleName(classTestName("A")), carpenterSchema.carpenterSchemas.first().name)
assertEquals(1, carpenterSchema.dependencies.size)
assert(mangleName(classTestName("B")) in carpenterSchema.dependencies)
assertEquals(1, carpenterSchema.dependsOn.size)
assert(mangleName(classTestName("A")) in carpenterSchema.dependsOn)
val metaCarpenter = TestMetaCarpenter(carpenterSchema)
assertEquals(0, metaCarpenter.objects.size)
// first iteration, carpent A, resolve deps and mark B as carpentable
metaCarpenter.build()
// one build iteration should have carpetned up A and worked out that B is now buildable
// given it's depedencies have been satisfied
assertTrue(mangleName(classTestName("A")) in metaCarpenter.objects)
assertFalse(mangleName(classTestName("B")) in metaCarpenter.objects)
assertEquals(1, carpenterSchema.carpenterSchemas.size)
assertEquals(mangleName(classTestName("B")), carpenterSchema.carpenterSchemas.first().name)
assertTrue(carpenterSchema.dependencies.isEmpty())
assertTrue(carpenterSchema.dependsOn.isEmpty())
// second manual iteration, will carpent B
metaCarpenter.build()
assert(mangleName(classTestName("A")) in metaCarpenter.objects)
assert(mangleName(classTestName("B")) in metaCarpenter.objects)
// and we must be finished
assertTrue(carpenterSchema.carpenterSchemas.isEmpty())
}
@Test(expected = UncarpentableException::class)
@Suppress("UNUSED")
fun nestedIsUnknownInherited() {
val testA = 10
val testB = 20
val testC = 30
@CordaSerializable
open class A(val a: Int)
@CordaSerializable
class B(a: Int, var b: Int) : A(a)
@CordaSerializable
data class C(val b: B, var c: Int)
val c = C(B(testA, testB), testC)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c))
assert(obj.obj is C)
val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B")))
amqpSchema.carpenterSchema()
}
@Test(expected = UncarpentableException::class)
@Suppress("UNUSED")
fun nestedIsUnknownInheritedUnknown() {
val testA = 10
val testB = 20
val testC = 30
@CordaSerializable
open class A(val a: Int)
@CordaSerializable
class B(a: Int, var b: Int) : A(a)
@CordaSerializable
data class C(val b: B, var c: Int)
val c = C(B(testA, testB), testC)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c))
assert(obj.obj is C)
val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B")))
amqpSchema.carpenterSchema()
}
@Suppress("UNUSED")
@Test(expected = UncarpentableException::class)
fun parentsIsUnknownWithUnknownInheritedMember() {
val testA = 10
val testB = 20
val testC = 30
@CordaSerializable
open class A(val a: Int)
@CordaSerializable
class B(a: Int, var b: Int) : A(a)
@CordaSerializable
data class C(val b: B, var c: Int)
val c = C(B(testA, testB), testC)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c))
assert(obj.obj is C)
val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B")))
TestMetaCarpenter(carpenterSchema.carpenterSchema())
}
/*
* TODO serializer doesn't support inheritnace at the moment, when it does this should work
@Test
fun `inheritance`() {
val testA = 10
val testB = 20
@CordaSerializable
open class A(open val a: Int)
@CordaSerializable
class B(override val a: Int, val b: Int) : A (a)
val b = B(testA, testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assert(obj.obj is B)
val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B")))
val metaCarpenter = TestMetaCarpenter(carpenterSchema.carpenterSchema())
assertEquals(1, metaCarpenter.schemas.carpenterSchemas.size)
assertEquals(mangleNames(classTestName("B")), metaCarpenter.schemas.carpenterSchemas.first().name)
assertEquals(1, metaCarpenter.schemas.dependencies.size)
assertTrue(mangleNames(classTestName("A")) in metaCarpenter.schemas.dependencies)
}
*/
}

View File

@ -0,0 +1,459 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.*
@CordaSerializable
interface J {
val j: Int
}
@CordaSerializable
interface I {
val i: Int
}
@CordaSerializable
interface II {
val ii: Int
}
@CordaSerializable
interface III : I {
val iii: Int
override val i: Int
}
@CordaSerializable
interface IIII {
val iiii: Int
val i: I
}
class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
@Test
fun interfaceParent1() {
class A(override val j: Int) : J
val testJ = 20
val a = A(testJ)
assertEquals(testJ, a.j)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is A)
val serSchema = obj.envelope.schema
assertEquals(2, serSchema.types.size)
val l1 = serSchema.carpenterSchema()
// since we're using an envelope generated by seilaising classes defined locally
// it's extremely unlikely we'd need to carpent any classes
assertEquals(0, l1.size)
val mangleSchema = serSchema.mangleNames(listOf(classTestName("A")))
val l2 = mangleSchema.carpenterSchema()
assertEquals(1, l2.size)
val aSchema = l2.carpenterSchemas.find { it.name == mangleName(classTestName("A")) }
assertNotEquals(null, aSchema)
assertEquals(mangleName(classTestName("A")), aSchema!!.name)
assertEquals(1, aSchema.interfaces.size)
assertEquals(J::class.java, aSchema.interfaces[0])
val aBuilder = ClassCarpenter().build(aSchema)
val objJ = aBuilder.constructors[0].newInstance(testJ)
val j = objJ as J
assertEquals(aBuilder.getMethod("getJ").invoke(objJ), testJ)
assertEquals(a.j, j.j)
}
@Test
fun interfaceParent2() {
class A(override val j: Int, val jj: Int) : J
val testJ = 20
val testJJ = 40
val a = A(testJ, testJJ)
assertEquals(testJ, a.j)
assertEquals(testJJ, a.jj)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is A)
val serSchema = obj.envelope.schema
assertEquals(2, serSchema.types.size)
val l1 = serSchema.carpenterSchema()
assertEquals(0, l1.size)
val mangleSchema = serSchema.mangleNames(listOf(classTestName("A")))
val aName = mangleName(classTestName("A"))
val l2 = mangleSchema.carpenterSchema()
assertEquals(1, l2.size)
val aSchema = l2.carpenterSchemas.find { it.name == aName }
assertNotEquals(null, aSchema)
assertEquals(aName, aSchema!!.name)
assertEquals(1, aSchema.interfaces.size)
assertEquals(J::class.java, aSchema.interfaces[0])
val aBuilder = ClassCarpenter().build(aSchema)
val objJ = aBuilder.constructors[0].newInstance(testJ, testJJ)
val j = objJ as J
assertEquals(aBuilder.getMethod("getJ").invoke(objJ), testJ)
assertEquals(aBuilder.getMethod("getJj").invoke(objJ), testJJ)
assertEquals(a.j, j.j)
}
@Test
fun multipleInterfaces() {
val testI = 20
val testII = 40
class A(override val i: Int, override val ii: Int) : I, II
val a = A(testI, testII)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is A)
val serSchema = obj.envelope.schema
assertEquals(3, serSchema.types.size)
val l1 = serSchema.carpenterSchema()
// since we're using an envelope generated by serialising classes defined locally
// it's extremely unlikely we'd need to carpent any classes
assertEquals(0, l1.size)
// pretend we don't know the class we've been sent, i.e. it's unknown to the class loader, and thus
// needs some carpentry
val mangleSchema = serSchema.mangleNames(listOf(classTestName("A")))
val l2 = mangleSchema.carpenterSchema()
val aName = mangleName(classTestName("A"))
assertEquals(1, l2.size)
val aSchema = l2.carpenterSchemas.find { it.name == aName }
assertNotEquals(null, aSchema)
assertEquals(aName, aSchema!!.name)
assertEquals(2, aSchema.interfaces.size)
assertTrue(I::class.java in aSchema.interfaces)
assertTrue(II::class.java in aSchema.interfaces)
val aBuilder = ClassCarpenter().build(aSchema)
val objA = aBuilder.constructors[0].newInstance(testI, testII)
val i = objA as I
val ii = objA as II
assertEquals(aBuilder.getMethod("getI").invoke(objA), testI)
assertEquals(aBuilder.getMethod("getIi").invoke(objA), testII)
assertEquals(a.i, i.i)
assertEquals(a.ii, ii.ii)
}
@Test
fun nestedInterfaces() {
class A(override val i: Int, override val iii: Int) : III
val testI = 20
val testIII = 60
val a = A(testI, testIII)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is A)
val serSchema = obj.envelope.schema
assertEquals(3, serSchema.types.size)
val l1 = serSchema.carpenterSchema()
// since we're using an envelope generated by serialising classes defined locally
// it's extremely unlikely we'd need to carpent any classes
assertEquals(0, l1.size)
val mangleSchema = serSchema.mangleNames(listOf(classTestName("A")))
val l2 = mangleSchema.carpenterSchema()
val aName = mangleName(classTestName("A"))
assertEquals(1, l2.size)
val aSchema = l2.carpenterSchemas.find { it.name == aName }
assertNotEquals(null, aSchema)
assertEquals(aName, aSchema!!.name)
assertEquals(2, aSchema.interfaces.size)
assertTrue(I::class.java in aSchema.interfaces)
assertTrue(III::class.java in aSchema.interfaces)
val aBuilder = ClassCarpenter().build(aSchema)
val objA = aBuilder.constructors[0].newInstance(testI, testIII)
val i = objA as I
val iii = objA as III
assertEquals(aBuilder.getMethod("getI").invoke(objA), testI)
assertEquals(aBuilder.getMethod("getIii").invoke(objA), testIII)
assertEquals(a.i, i.i)
assertEquals(a.i, iii.i)
assertEquals(a.iii, iii.iii)
}
@Test
fun memberInterface() {
class A(override val i: Int) : I
class B(override val i: I, override val iiii: Int) : IIII
val testI = 25
val testIIII = 50
val a = A(testI)
val b = B(a, testIIII)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assertTrue(obj.obj is B)
val serSchema = obj.envelope.schema
// Expected classes are
// * class A
// * class A's interface (class I)
// * class B
// * class B's interface (class IIII)
assertEquals(4, serSchema.types.size)
val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"), classTestName("B")))
val cSchema = mangleSchema.carpenterSchema()
val aName = mangleName(classTestName("A"))
val bName = mangleName(classTestName("B"))
assertEquals(2, cSchema.size)
val aCarpenterSchema = cSchema.carpenterSchemas.find { it.name == aName }
val bCarpenterSchema = cSchema.carpenterSchemas.find { it.name == bName }
assertNotEquals(null, aCarpenterSchema)
assertNotEquals(null, bCarpenterSchema)
val cc = ClassCarpenter()
val cc2 = ClassCarpenter()
val bBuilder = cc.build(bCarpenterSchema!!)
bBuilder.constructors[0].newInstance(a, testIIII)
val aBuilder = cc.build(aCarpenterSchema!!)
val objA = aBuilder.constructors[0].newInstance(testI)
// build a second B this time using our constructed instance of A and not the
// local one we pre defined
bBuilder.constructors[0].newInstance(objA, testIIII)
// whittle and instantiate a different A with a new class loader
val aBuilder2 = cc2.build(aCarpenterSchema)
val objA2 = aBuilder2.constructors[0].newInstance(testI)
bBuilder.constructors[0].newInstance(objA2, testIIII)
}
// if we remove the nested interface we should get an error as it's impossible
// to have a concrete class loaded without having access to all of it's elements
@Test(expected = UncarpentableException::class)
fun memberInterface2() {
class A(override val i: Int) : I
class B(override val i: I, override val iiii: Int) : IIII
val testI = 25
val testIIII = 50
val a = A(testI)
val b = B(a, testIIII)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
assertTrue(obj.obj is B)
val serSchema = obj.envelope.schema
// The classes we're expecting to find:
// * class A
// * class A's interface (class I)
// * class B
// * class B's interface (class IIII)
assertEquals(4, serSchema.types.size)
// ignore the return as we expect this to throw
serSchema.mangleNames(listOf(
classTestName("A"), "${this.javaClass.`package`.name}.I")).carpenterSchema()
}
@Test
fun interfaceAndImplementation() {
class A(override val i: Int) : I
val testI = 25
val a = A(testI)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assertTrue(obj.obj is A)
val serSchema = obj.envelope.schema
// The classes we're expecting to find:
// * class A
// * class A's interface (class I)
assertEquals(2, serSchema.types.size)
val amqpSchema = serSchema.mangleNames(listOf(classTestName("A"), "${this.javaClass.`package`.name}.I"))
val aName = mangleName(classTestName("A"))
val iName = mangleName("${this.javaClass.`package`.name}.I")
val carpenterSchema = amqpSchema.carpenterSchema()
// whilst there are two unknown classes within the envelope A depends on I so we can't construct a
// schema for A until we have for I
assertEquals(1, carpenterSchema.size)
assertNotEquals(null, carpenterSchema.carpenterSchemas.find { it.name == iName })
// since we can't build A it should list I as a dependency
assertTrue(aName in carpenterSchema.dependencies)
assertEquals(1, carpenterSchema.dependencies[aName]!!.second.size)
assertEquals(iName, carpenterSchema.dependencies[aName]!!.second[0])
// and conversly I should have A listed as a dependent
assertTrue(iName in carpenterSchema.dependsOn)
assertEquals(1, carpenterSchema.dependsOn[iName]!!.size)
assertEquals(aName, carpenterSchema.dependsOn[iName]!![0])
val mc = MetaCarpenter(carpenterSchema)
mc.build()
assertEquals(0, mc.schemas.carpenterSchemas.size)
assertEquals(0, mc.schemas.dependencies.size)
assertEquals(0, mc.schemas.dependsOn.size)
assertEquals(2, mc.objects.size)
assertTrue(aName in mc.objects)
assertTrue(iName in mc.objects)
mc.objects[aName]!!.constructors[0].newInstance(testI)
}
@Test
fun twoInterfacesAndImplementation() {
class A(override val i: Int, override val ii: Int) : I, II
val testI = 69
val testII = 96
val a = A(testI, testII)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
val amqpSchema = obj.envelope.schema.mangleNames(listOf(
classTestName("A"),
"${this.javaClass.`package`.name}.I",
"${this.javaClass.`package`.name}.II"))
val aName = mangleName(classTestName("A"))
val iName = mangleName("${this.javaClass.`package`.name}.I")
val iiName = mangleName("${this.javaClass.`package`.name}.II")
val carpenterSchema = amqpSchema.carpenterSchema()
// there is nothing preventing us from carpenting up the two interfaces so
// our initial list should contain both interface with A being dependent on both
// and each having A as a dependent
assertEquals(2, carpenterSchema.carpenterSchemas.size)
assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iName })
assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iiName })
assertNull(carpenterSchema.carpenterSchemas.find { it.name == aName })
assertTrue(iName in carpenterSchema.dependsOn)
assertEquals(1, carpenterSchema.dependsOn[iName]?.size)
assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == aName }))
assertTrue(iiName in carpenterSchema.dependsOn)
assertEquals(1, carpenterSchema.dependsOn[iiName]?.size)
assertNotNull(carpenterSchema.dependsOn[iiName]?.find { it == aName })
assertTrue(aName in carpenterSchema.dependencies)
assertEquals(2, carpenterSchema.dependencies[aName]!!.second.size)
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName })
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiName })
val mc = MetaCarpenter(carpenterSchema)
mc.build()
assertEquals(0, mc.schemas.carpenterSchemas.size)
assertEquals(0, mc.schemas.dependencies.size)
assertEquals(0, mc.schemas.dependsOn.size)
assertEquals(3, mc.objects.size)
assertTrue(aName in mc.objects)
assertTrue(iName in mc.objects)
assertTrue(iiName in mc.objects)
}
@Test
fun nestedInterfacesAndImplementation() {
class A(override val i: Int, override val iii: Int) : III
val testI = 7
val testIII = 11
val a = A(testI, testIII)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
val amqpSchema = obj.envelope.schema.mangleNames(listOf(
classTestName("A"),
"${this.javaClass.`package`.name}.I",
"${this.javaClass.`package`.name}.III"))
val aName = mangleName(classTestName("A"))
val iName = mangleName("${this.javaClass.`package`.name}.I")
val iiiName = mangleName("${this.javaClass.`package`.name}.III")
val carpenterSchema = amqpSchema.carpenterSchema()
// Since A depends on III and III extends I we will have to construct them
// in that reverse order (I -> III -> A)
assertEquals(1, carpenterSchema.carpenterSchemas.size)
assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iName })
assertNull(carpenterSchema.carpenterSchemas.find { it.name == iiiName })
assertNull(carpenterSchema.carpenterSchemas.find { it.name == aName })
// I has III as a direct dependent and A as an indirect one
assertTrue(iName in carpenterSchema.dependsOn)
assertEquals(2, carpenterSchema.dependsOn[iName]?.size)
assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == iiiName }))
assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == aName }))
// III has A as a dependent
assertTrue(iiiName in carpenterSchema.dependsOn)
assertEquals(1, carpenterSchema.dependsOn[iiiName]?.size)
assertNotNull(carpenterSchema.dependsOn[iiiName]?.find { it == aName })
// conversly III depends on I
assertTrue(iiiName in carpenterSchema.dependencies)
assertEquals(1, carpenterSchema.dependencies[iiiName]!!.second.size)
assertNotNull(carpenterSchema.dependencies[iiiName]!!.second.find { it == iName })
// and A depends on III and I
assertTrue(aName in carpenterSchema.dependencies)
assertEquals(2, carpenterSchema.dependencies[aName]!!.second.size)
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiiName })
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName })
val mc = MetaCarpenter(carpenterSchema)
mc.build()
assertEquals(0, mc.schemas.carpenterSchemas.size)
assertEquals(0, mc.schemas.dependencies.size)
assertEquals(0, mc.schemas.dependsOn.size)
assertEquals(3, mc.objects.size)
assertTrue(aName in mc.objects)
assertTrue(iName in mc.objects)
assertTrue(iiiName in mc.objects)
}
}

View File

@ -0,0 +1,95 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
@Test
fun twoInts() {
@CordaSerializable
data class A(val a: Int, val b: Int)
val testA = 10
val testB = 20
val a = A(testA, testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(testA, amqpObj.a)
assertEquals(testB, amqpObj.b)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(2, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("int", amqpSchema.fields[0].type)
assertEquals("b", amqpSchema.fields[1].name)
assertEquals("int", amqpSchema.fields[1].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
assertEquals(1, carpenterSchema.size)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }
assertNotEquals(null, aSchema)
val pinochio = ClassCarpenter().build(aSchema!!)
val p = pinochio.constructors[0].newInstance(testA, testB)
assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a)
assertEquals(pinochio.getMethod("getB").invoke(p), amqpObj.b)
}
@Test
fun intAndStr() {
@CordaSerializable
data class A(val a: Int, val b: String)
val testA = 10
val testB = "twenty"
val a = A(testA, testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(testA, amqpObj.a)
assertEquals(testB, amqpObj.b)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(2, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("int", amqpSchema.fields[0].type)
assertEquals("b", amqpSchema.fields[1].name)
assertEquals("string", amqpSchema.fields[1].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
assertEquals(1, carpenterSchema.size)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }
assertNotEquals(null, aSchema)
val pinochio = ClassCarpenter().build(aSchema!!)
val p = pinochio.constructors[0].newInstance(testA, testB)
assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a)
assertEquals(pinochio.getMethod("getB").invoke(p), amqpObj.b)
}
}

View File

@ -0,0 +1,197 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.assertEquals
class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
@Test
fun singleInteger() {
@CordaSerializable
data class A(val a: Int)
val test = 10
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(test, amqpObj.a)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(1, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("int", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
}
@Test
fun singleString() {
@CordaSerializable
data class A(val a: String)
val test = "ten"
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(test, amqpObj.a)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
}
@Test
fun singleLong() {
@CordaSerializable
data class A(val a: Long)
val test = 10L
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(test, amqpObj.a)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(1, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("long", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
}
@Test
fun singleShort() {
@CordaSerializable
data class A(val a: Short)
val test = 10.toShort()
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(test, amqpObj.a)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(1, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("short", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
}
@Test
fun singleDouble() {
@CordaSerializable
data class A(val a: Double)
val test = 10.0
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(test, amqpObj.a)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(1, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("double", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
}
@Test
fun singleFloat() {
@CordaSerializable
data class A(val a: Float)
val test: Float = 10.0F
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
assert(obj.obj is A)
val amqpObj = obj.obj as A
assertEquals(test, amqpObj.a)
assertEquals(1, obj.envelope.schema.types.size)
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
assertEquals(1, amqpSchema.fields.size)
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("float", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
}
}