diff --git a/core/build.gradle b/core/build.gradle index f0029a6672..131f8acf86 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -76,7 +76,7 @@ dependencies { compile "io.requery:requery-kotlin:$requery_version" // For AMQP serialisation. - compile "org.apache.qpid:proton-j:0.18.0" + compile "org.apache.qpid:proton-j:0.19.0" } configurations { diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt index d1114d0946..92a8998288 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt @@ -1,6 +1,7 @@ package net.corda.core.flows -import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.CordaException +import net.corda.core.utilities.CordaRuntimeException /** * Exception which can be thrown by a [FlowLogic] at any point in its logic to unexpectedly bring it to a permanent end. @@ -11,8 +12,7 @@ import net.corda.core.serialization.CordaSerializable * [FlowException] (or a subclass) can be a valid expected response from a flow, particularly ones which act as a service. * It is recommended a [FlowLogic] document the [FlowException] types it can throw. */ -@CordaSerializable -open class FlowException(override val message: String?, override val cause: Throwable?) : Exception() { +open class FlowException(message: String?, cause: Throwable?) : CordaException(message, cause) { constructor(message: String?) : this(message, null) constructor(cause: Throwable?) : this(cause?.toString(), cause) constructor() : this(null, null) @@ -23,5 +23,6 @@ open class FlowException(override val message: String?, override val cause: Thro * that we were not expecting), or the other side had an internal error, or the other side terminated when we * were waiting for a response. */ -@CordaSerializable -class FlowSessionException(message: String) : RuntimeException(message) +class FlowSessionException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { + constructor(msg: String) : this(msg, null) +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt index 2935b19cb9..40f586a88e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt @@ -7,7 +7,7 @@ import java.lang.reflect.Type /** * Serializer / deserializer for native AMQP types (Int, Float, String etc). */ -class AMQPPrimitiveSerializer(clazz: Class<*>) : AMQPSerializer { +class AMQPPrimitiveSerializer(clazz: Class<*>) : AMQPSerializer { override val typeDescriptor: String = SerializerFactory.primitiveTypeName(Primitives.wrap(clazz))!! override val type: Type = clazz @@ -19,5 +19,5 @@ class AMQPPrimitiveSerializer(clazz: Class<*>) : AMQPSerializer { data.putObject(obj) } - override fun readObject(obj: Any, envelope: Envelope, input: DeserializationInput): Any = obj + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any = obj } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt index 20465bb9cb..b2917c39cd 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt @@ -6,7 +6,7 @@ import java.lang.reflect.Type /** * Implemented to serialize and deserialize different types of objects to/from AMQP. */ -interface AMQPSerializer { +interface AMQPSerializer { /** * The JVM type this can serialize and deserialize. */ @@ -34,5 +34,5 @@ interface AMQPSerializer { /** * Read the given object from the input. The envelope is provided in case the schema is required. */ - fun readObject(obj: Any, envelope: Envelope, input: DeserializationInput): Any + fun readObject(obj: Any, schema: Schema, input: DeserializationInput): T } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt index 2b1c6f5c55..0cf705e16d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt @@ -9,14 +9,12 @@ import java.lang.reflect.Type /** * Serialization / deserialization of arrays. */ -class ArraySerializer(override val type: Type) : AMQPSerializer { - private val typeName = type.typeName +class ArraySerializer(override val type: Type, factory: SerializerFactory) : AMQPSerializer { + override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" - override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type)}" + internal val elementType: Type = makeElementType() - private val elementType: Type = makeElementType() - - private val typeNotation: TypeNotation = RestrictedType(typeName, null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList()) + private val typeNotation: TypeNotation = RestrictedType(type.typeName, null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList()) private fun makeElementType(): Type { return (type as? Class<*>)?.componentType ?: (type as GenericArrayType).genericComponentType @@ -39,8 +37,10 @@ class ArraySerializer(override val type: Type) : AMQPSerializer { } } - override fun readObject(obj: Any, envelope: Envelope, input: DeserializationInput): Any { - return (obj as List<*>).map { input.readObjectOrNull(it, envelope, elementType) }.toArrayOfType(elementType) + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { + if (obj is List<*>) { + return obj.map { input.readObjectOrNull(it, schema, elementType) }.toArrayOfType(elementType) + } else throw NotSerializableException("Expected a List but found $obj") } private fun List.toArrayOfType(type: Type): Any { diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt index 3e2d74002c..0f4421de6c 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt @@ -12,28 +12,27 @@ import kotlin.collections.Set /** * Serialization / deserialization of predefined set of supported [Collection] types covering mostly [List]s and [Set]s. */ -class CollectionSerializer(val declaredType: ParameterizedType) : AMQPSerializer { +class CollectionSerializer(val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer { override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(declaredType.toString()) - private val typeName = declaredType.toString() - override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type)}" + override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" companion object { - private val supportedTypes: Map>, (Collection<*>) -> Collection<*>> = mapOf( - Collection::class.java to { coll -> coll }, - List::class.java to { coll -> coll }, - Set::class.java to { coll -> Collections.unmodifiableSet(LinkedHashSet(coll)) }, - SortedSet::class.java to { coll -> Collections.unmodifiableSortedSet(TreeSet(coll)) }, - NavigableSet::class.java to { coll -> Collections.unmodifiableNavigableSet(TreeSet(coll)) } + private val supportedTypes: Map>, (List<*>) -> Collection<*>> = mapOf( + Collection::class.java to { list -> Collections.unmodifiableCollection(list) }, + List::class.java to { list -> Collections.unmodifiableList(list) }, + Set::class.java to { list -> Collections.unmodifiableSet(LinkedHashSet(list)) }, + SortedSet::class.java to { list -> Collections.unmodifiableSortedSet(TreeSet(list)) }, + NavigableSet::class.java to { list -> Collections.unmodifiableNavigableSet(TreeSet(list)) } ) + + private fun findConcreteType(clazz: Class<*>): (List<*>) -> Collection<*> { + return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported collection type $clazz.") + } } - private val concreteBuilder: (Collection<*>) -> Collection<*> = findConcreteType(declaredType.rawType as Class<*>) + private val concreteBuilder: (List<*>) -> Collection<*> = findConcreteType(declaredType.rawType as Class<*>) - private fun findConcreteType(clazz: Class<*>): (Collection<*>) -> Collection<*> { - return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported map type $clazz.") - } - - private val typeNotation: TypeNotation = RestrictedType(typeName, null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList()) + private val typeNotation: TypeNotation = RestrictedType(declaredType.toString(), null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList()) override fun writeClassInfo(output: SerializationOutput) { if (output.writeTypeNotations(typeNotation)) { @@ -52,8 +51,8 @@ class CollectionSerializer(val declaredType: ParameterizedType) : AMQPSerializer } } - override fun readObject(obj: Any, envelope: Envelope, input: DeserializationInput): Any { + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { // TODO: Can we verify the entries in the list? - return concreteBuilder((obj as List<*>).map { input.readObjectOrNull(it, envelope, declaredType.actualTypeArguments[0]) }) + return concreteBuilder((obj as List<*>).map { input.readObjectOrNull(it, schema, declaredType.actualTypeArguments[0]) }) } } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/CustomSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/CustomSerializer.kt new file mode 100644 index 0000000000..e88230de3d --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/CustomSerializer.kt @@ -0,0 +1,105 @@ +package net.corda.core.serialization.amqp + +import org.apache.qpid.proton.codec.Data +import java.lang.reflect.Type + +/** + * Base class for serializers of core platform types that do not conform to the usual serialization rules and thus + * cannot be automatically serialized. + */ +abstract class CustomSerializer : AMQPSerializer { + /** + * This is a collection of custom serializers that this custom serializer depends on. e.g. for proxy objects + * that refer to arrays of types etc. + */ + abstract val additionalSerializers: Iterable> + + abstract fun isSerializerFor(clazz: Class<*>): Boolean + protected abstract val descriptor: Descriptor + /** + * This exists purely for documentation and cross-platform purposes. It is not used by our serialization / deserialization + * code path. + */ + abstract val schemaForDocumentation: Schema + + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + data.withDescribed(descriptor) { + @Suppress("UNCHECKED_CAST") + writeDescribedObject(obj as T, data, type, output) + } + } + + abstract fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput) + + /** + * Additional base features for a custom serializer that is a particular class. + */ + abstract class Is(protected val clazz: Class) : CustomSerializer() { + override fun isSerializerFor(clazz: Class<*>): Boolean = clazz == this.clazz + override val type: Type get() = clazz + override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${clazz.name}" + override fun writeClassInfo(output: SerializationOutput) {} + override val descriptor: Descriptor = Descriptor(typeDescriptor) + } + + /** + * Additional base features for a custom serializer for all implementations of a particular interface or super class. + */ + abstract class Implements(protected val clazz: Class) : CustomSerializer() { + override fun isSerializerFor(clazz: Class<*>): Boolean = this.clazz.isAssignableFrom(clazz) + override val type: Type get() = clazz + override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${clazz.name}" + override fun writeClassInfo(output: SerializationOutput) {} + override val descriptor: Descriptor = Descriptor(typeDescriptor) + } + + /** + * Addition base features over and above [Implements] or [Is] custom serializer for when the serialize form should be + * the serialized form of a proxy class, and the object can be re-created from that proxy on deserialization. + * + * The proxy class must use only types which are either native AMQP or other types for which there are pre-registered + * custom serializers. + */ + abstract class Proxy(protected val clazz: Class, + protected val proxyClass: Class

, + protected val factory: SerializerFactory, + val withInheritance: Boolean = true) : CustomSerializer() { + override fun isSerializerFor(clazz: Class<*>): Boolean = if (withInheritance) this.clazz.isAssignableFrom(clazz) else this.clazz == clazz + override val type: Type get() = clazz + override val typeDescriptor: String = "$DESCRIPTOR_DOMAIN:${clazz.name}" + override fun writeClassInfo(output: SerializationOutput) {} + override val descriptor: Descriptor = Descriptor(typeDescriptor) + + private val proxySerializer: ObjectSerializer by lazy { ObjectSerializer(proxyClass, factory) } + + override val schemaForDocumentation: Schema by lazy { + val typeNotations = mutableSetOf(CompositeType(type.typeName, null, emptyList(), descriptor, (proxySerializer.typeNotation as CompositeType).fields)) + for (additional in additionalSerializers) { + typeNotations.addAll(additional.schemaForDocumentation.types) + } + Schema(typeNotations.toList()) + } + + /** + * Implement these two methods. + */ + protected abstract fun toProxy(obj: T): P + + protected abstract fun fromProxy(proxy: P): T + + override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput) { + val proxy = toProxy(obj) + data.withList { + for (property in proxySerializer.propertySerializers) { + property.writeProperty(proxy, this, output) + } + } + } + + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): T { + @Suppress("UNCHECKED_CAST") + val proxy = proxySerializer.readObject(obj, schema, input) as P + return fromProxy(proxy) + } + } +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt index b47d75b8bc..ccbe1fac20 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt @@ -15,7 +15,7 @@ import java.util.* * @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple * instances and threads. */ -class DeserializationInput(private val serializerFactory: SerializerFactory = SerializerFactory()) { +class DeserializationInput(internal val serializerFactory: SerializerFactory = SerializerFactory()) { // TODO: we're not supporting object refs yet private val objectHistory: MutableList = ArrayList() @@ -41,7 +41,7 @@ class DeserializationInput(private val serializerFactory: SerializerFactory = Se } val envelope = Envelope.get(data) // Now pick out the obj and schema from the envelope. - return clazz.cast(readObjectOrNull(envelope.obj, envelope, clazz)) + return clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)) } catch(nse: NotSerializableException) { throw nse } catch(t: Throwable) { @@ -51,20 +51,21 @@ class DeserializationInput(private val serializerFactory: SerializerFactory = Se } } - internal fun readObjectOrNull(obj: Any?, envelope: Envelope, type: Type): Any? { + internal fun readObjectOrNull(obj: Any?, schema: Schema, type: Type): Any? { if (obj == null) { return null } else { - return readObject(obj, envelope, type) + return readObject(obj, schema, type) } } - internal fun readObject(obj: Any, envelope: Envelope, type: Type): Any { + internal fun readObject(obj: Any, schema: Schema, type: Type): Any { if (obj is DescribedType) { // Look up serializer in factory by descriptor - val serializer = serializerFactory.get(obj.descriptor, envelope) - if (serializer.type != type && !serializer.type.isSubClassOf(type)) throw NotSerializableException("Described type with descriptor ${obj.descriptor} was expected to be of type $type") - return serializer.readObject(obj.described, envelope, this) + val serializer = serializerFactory.get(obj.descriptor, schema) + if (serializer.type != type && !serializer.type.isSubClassOf(type)) + throw NotSerializableException("Described type with descriptor ${obj.descriptor} was expected to be of type $type") + return serializer.readObject(obj.described, schema, this) } else { return obj } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt index 2cd0ae1298..9a0809d18d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt @@ -119,7 +119,7 @@ class DeserializedParameterizedType(private val rawType: Class<*>, private val p private fun makeType(typeName: String, cl: ClassLoader): Type { // Not generic - return if (typeName == "*") SerializerFactory.AnyType else Class.forName(typeName, false, cl) + return if (typeName == "?") SerializerFactory.AnyType else Class.forName(typeName, false, cl) } private fun makeParameterizedType(rawTypeName: String, args: MutableList, cl: ClassLoader): Type { diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt index 2ea61c6598..7991648f1a 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt @@ -13,10 +13,9 @@ import kotlin.collections.map /** * Serialization / deserialization of certain supported [Map] types. */ -class MapSerializer(val declaredType: ParameterizedType) : AMQPSerializer { +class MapSerializer(val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer { override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(declaredType.toString()) - private val typeName = declaredType.toString() - override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type)}" + override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" companion object { private val supportedTypes: Map>, (Map<*, *>) -> Map<*, *>> = mapOf( @@ -24,15 +23,15 @@ class MapSerializer(val declaredType: ParameterizedType) : AMQPSerializer { SortedMap::class.java to { map -> Collections.unmodifiableSortedMap(TreeMap(map)) }, NavigableMap::class.java to { map -> Collections.unmodifiableNavigableMap(TreeMap(map)) } ) + + private fun findConcreteType(clazz: Class<*>): (Map<*, *>) -> Map<*, *> { + return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported map type $clazz.") + } } private val concreteBuilder: (Map<*, *>) -> Map<*, *> = findConcreteType(declaredType.rawType as Class<*>) - private fun findConcreteType(clazz: Class<*>): (Map<*, *>) -> Map<*, *> { - return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported map type $clazz.") - } - - private val typeNotation: TypeNotation = RestrictedType(typeName, null, emptyList(), "map", Descriptor(typeDescriptor, null), emptyList()) + private val typeNotation: TypeNotation = RestrictedType(declaredType.toString(), null, emptyList(), "map", Descriptor(typeDescriptor, null), emptyList()) override fun writeClassInfo(output: SerializationOutput) { if (output.writeTypeNotations(typeNotation)) { @@ -56,11 +55,13 @@ class MapSerializer(val declaredType: ParameterizedType) : AMQPSerializer { } } - override fun readObject(obj: Any, envelope: Envelope, input: DeserializationInput): Any { + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { // TODO: General generics question. Do we need to validate that entries in Maps and Collections match the generic type? Is it a security hole? - val entries: Iterable> = (obj as Map<*, *>).map { readEntry(envelope, input, it) } + val entries: Iterable> = (obj as Map<*, *>).map { readEntry(schema, input, it) } return concreteBuilder(entries.toMap()) } - private fun readEntry(envelope: Envelope, input: DeserializationInput, entry: Map.Entry) = input.readObjectOrNull(entry.key, envelope, declaredType.actualTypeArguments[0]) to input.readObjectOrNull(entry.value, envelope, declaredType.actualTypeArguments[1]) + private fun readEntry(schema: Schema, input: DeserializationInput, entry: Map.Entry) = + input.readObjectOrNull(entry.key, schema, declaredType.actualTypeArguments[0]) to + input.readObjectOrNull(entry.value, schema, declaredType.actualTypeArguments[1]) } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt index 2ccfad81d6..130d50d7a3 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt @@ -10,26 +10,30 @@ import kotlin.reflect.jvm.javaConstructor /** * Responsible for serializing and deserializing a regular object instance via a series of properties (matched with a constructor). */ -class ObjectSerializer(val clazz: Class<*>) : AMQPSerializer { +class ObjectSerializer(val clazz: Class<*>, factory: SerializerFactory) : AMQPSerializer { override val type: Type get() = clazz private val javaConstructor: Constructor? - private val propertySerializers: Collection + internal val propertySerializers: Collection init { val kotlinConstructor = constructorForDeserialization(clazz) javaConstructor = kotlinConstructor?.javaConstructor - propertySerializers = propertiesForSerialization(kotlinConstructor, clazz) + propertySerializers = propertiesForSerialization(kotlinConstructor, clazz, factory) } private val typeName = clazz.name - override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type)}" + override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" private val interfaces = interfacesForSerialization(clazz) // TODO maybe this proves too much and we need annotations to restrict. - private val typeNotation: TypeNotation = CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor, null), generateFields()) + internal val typeNotation: TypeNotation = CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor, null), generateFields()) override fun writeClassInfo(output: SerializationOutput) { - output.writeTypeNotations(typeNotation) - for (iface in interfaces) { - output.requireSerializer(iface) + if (output.writeTypeNotations(typeNotation)) { + for (iface in interfaces) { + output.requireSerializer(iface) + } + for (property in propertySerializers) { + property.writeClassInfo(output) + } } } @@ -45,13 +49,13 @@ class ObjectSerializer(val clazz: Class<*>) : AMQPSerializer { } } - override fun readObject(obj: Any, envelope: Envelope, input: DeserializationInput): Any { + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { if (obj is UnsignedInteger) { // TODO: Object refs TODO("not implemented") //To change body of created functions use File | Settings | File Templates. } else if (obj is List<*>) { if (obj.size > propertySerializers.size) throw NotSerializableException("Too many properties in described type $typeName") - val params = obj.zip(propertySerializers).map { it.second.readProperty(it.first, envelope, input) } + val params = obj.zip(propertySerializers).map { it.second.readProperty(it.first, schema, input) } return construct(params) } else throw NotSerializableException("Body of described type is unexpected $obj") } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt index 50cb6c5581..6e26441633 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt @@ -9,8 +9,9 @@ import kotlin.reflect.jvm.javaGetter * Base class for serialization of a property of an object. */ sealed class PropertySerializer(val name: String, val readMethod: Method) { + abstract fun writeClassInfo(output: SerializationOutput) abstract fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) - abstract fun readProperty(obj: Any?, envelope: Envelope, input: DeserializationInput): Any? + abstract fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? val type: String = generateType() val requires: List = generateRequires() @@ -53,13 +54,13 @@ sealed class PropertySerializer(val name: String, val readMethod: Method) { } companion object { - fun make(name: String, readMethod: Method): PropertySerializer { + fun make(name: String, readMethod: Method, factory: SerializerFactory): PropertySerializer { val type = readMethod.genericReturnType if (SerializerFactory.isPrimitive(type)) { // This is a little inefficient for performance since it does a runtime check of type. We could do build time check with lots of subclasses here. return AMQPPrimitivePropertySerializer(name, readMethod) } else { - return DescribedTypePropertySerializer(name, readMethod) + return DescribedTypePropertySerializer(name, readMethod) { factory.get(null, type) } } } } @@ -67,9 +68,16 @@ sealed class PropertySerializer(val name: String, val readMethod: Method) { /** * A property serializer for a complex type (another object). */ - class DescribedTypePropertySerializer(name: String, readMethod: Method) : PropertySerializer(name, readMethod) { - override fun readProperty(obj: Any?, envelope: Envelope, input: DeserializationInput): Any? { - return input.readObjectOrNull(obj, envelope, readMethod.genericReturnType) + class DescribedTypePropertySerializer(name: String, readMethod: Method, private val lazyTypeSerializer: () -> AMQPSerializer) : PropertySerializer(name, readMethod) { + // This is lazy so we don't get an infinite loop when a method returns an instance of the class. + private val typeSerializer: AMQPSerializer by lazy { lazyTypeSerializer() } + + override fun writeClassInfo(output: SerializationOutput) { + typeSerializer.writeClassInfo(output) + } + + override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? { + return input.readObjectOrNull(obj, schema, readMethod.genericReturnType) } override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) { @@ -81,7 +89,9 @@ sealed class PropertySerializer(val name: String, val readMethod: Method) { * A property serializer for an AMQP primitive type (Int, String, etc). */ class AMQPPrimitivePropertySerializer(name: String, readMethod: Method) : PropertySerializer(name, readMethod) { - override fun readProperty(obj: Any?, envelope: Envelope, input: DeserializationInput): Any? { + override fun writeClassInfo(output: SerializationOutput) {} + + override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? { return obj } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt index 64a28a7aae..5c627cc943 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt @@ -87,7 +87,7 @@ data class Schema(val types: List) : DescribedType { override fun toString(): String = types.joinToString("\n") } -data class Descriptor(val name: String?, val code: UnsignedLong?) : DescribedType { +data class Descriptor(val name: String?, val code: UnsignedLong? = null) : DescribedType { companion object : DescribedTypeConstructor { val DESCRIPTOR = UnsignedLong(3L or DESCRIPTOR_TOP_32BITS) @@ -320,9 +320,9 @@ private val ANY_TYPE_HASH: String = "Any type = true" * different. */ // TODO: write tests -internal fun fingerprintForType(type: Type): String = Base58.encode(fingerprintForType(type, HashSet(), Hashing.murmur3_128().newHasher()).hash().asBytes()) +internal fun fingerprintForType(type: Type, factory: SerializerFactory): String = Base58.encode(fingerprintForType(type, HashSet(), Hashing.murmur3_128().newHasher(), factory).hash().asBytes()) -private fun fingerprintForType(type: Type, alreadySeen: MutableSet, hasher: Hasher): Hasher { +private fun fingerprintForType(type: Type, alreadySeen: MutableSet, hasher: Hasher, factory: SerializerFactory): Hasher { return if (type in alreadySeen) { hasher.putUnencodedChars(ALREADY_SEEN_HASH) } else { @@ -331,25 +331,31 @@ private fun fingerprintForType(type: Type, alreadySeen: MutableSet, hasher hasher.putUnencodedChars(ANY_TYPE_HASH) } else if (type is Class<*>) { if (type.isArray) { - fingerprintForType(type.componentType, alreadySeen, hasher).putUnencodedChars(ARRAY_HASH) + fingerprintForType(type.componentType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH) } else if (SerializerFactory.isPrimitive(type)) { hasher.putUnencodedChars(type.name) } else if (Collection::class.java.isAssignableFrom(type) || Map::class.java.isAssignableFrom(type)) { hasher.putUnencodedChars(type.name) } else { - // Hash the class + properties + interfaces - propertiesForSerialization(constructorForDeserialization(type), type).fold(hasher.putUnencodedChars(type.name)) { orig, param -> - fingerprintForType(param.readMethod.genericReturnType, alreadySeen, orig).putUnencodedChars(param.name).putUnencodedChars(if (param.mandatory) NOT_NULLABLE_HASH else NULLABLE_HASH) + // Need to check if a custom serializer is applicable + val customSerializer = factory.findCustomSerializer(type) + if (customSerializer == null) { + // Hash the class + properties + interfaces + propertiesForSerialization(constructorForDeserialization(type), type, factory).fold(hasher.putUnencodedChars(type.name)) { orig, param -> + fingerprintForType(param.readMethod.genericReturnType, alreadySeen, orig, factory).putUnencodedChars(param.name).putUnencodedChars(if (param.mandatory) NOT_NULLABLE_HASH else NULLABLE_HASH) + } + interfacesForSerialization(type).map { fingerprintForType(it, alreadySeen, hasher, factory) } + hasher + } else { + hasher.putUnencodedChars(customSerializer.typeDescriptor) } - interfacesForSerialization(type).map { fingerprintForType(it, alreadySeen, hasher) } - hasher } } else if (type is ParameterizedType) { // Hash the rawType + params - type.actualTypeArguments.fold(fingerprintForType(type.rawType, alreadySeen, hasher)) { orig, paramType -> fingerprintForType(paramType, alreadySeen, orig) } + type.actualTypeArguments.fold(fingerprintForType(type.rawType, alreadySeen, hasher, factory)) { orig, paramType -> fingerprintForType(paramType, alreadySeen, orig, factory) } } else if (type is GenericArrayType) { // Hash the element type + some array hash - fingerprintForType(type.genericComponentType, alreadySeen, hasher).putUnencodedChars(ARRAY_HASH) + fingerprintForType(type.genericComponentType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH) } else { throw NotSerializableException("Don't know how to hash $type") } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt index 107769cde7..85082544a4 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt @@ -1,14 +1,16 @@ package net.corda.core.serialization.amqp +import com.google.common.reflect.TypeToken import org.apache.qpid.proton.codec.Data import java.beans.Introspector -import java.beans.PropertyDescriptor import java.io.NotSerializableException +import java.lang.reflect.Method import java.lang.reflect.Modifier import java.lang.reflect.ParameterizedType import java.lang.reflect.Type import kotlin.reflect.KClass import kotlin.reflect.KFunction +import kotlin.reflect.KParameter import kotlin.reflect.full.findAnnotation import kotlin.reflect.full.primaryConstructor import kotlin.reflect.jvm.javaType @@ -58,24 +60,26 @@ internal fun constructorForDeserialization(clazz: Class): KFunction * Note, you will need any Java classes to be compiled with the `-parameters` option to ensure constructor parameters have * names accessible via reflection. */ -internal fun propertiesForSerialization(kotlinConstructor: KFunction?, clazz: Class<*>): Collection { - return if (kotlinConstructor != null) propertiesForSerialization(kotlinConstructor) else propertiesForSerialization(clazz) +internal fun propertiesForSerialization(kotlinConstructor: KFunction?, clazz: Class<*>, factory: SerializerFactory): Collection { + return if (kotlinConstructor != null) propertiesForSerialization(kotlinConstructor, factory) else propertiesForSerialization(clazz, factory) } private fun isConcrete(clazz: Class<*>): Boolean = !(clazz.isInterface || Modifier.isAbstract(clazz.modifiers)) -private fun propertiesForSerialization(kotlinConstructor: KFunction): Collection { +private fun propertiesForSerialization(kotlinConstructor: KFunction, factory: SerializerFactory): Collection { val clazz = (kotlinConstructor.returnType.classifier as KClass<*>).javaObjectType // Kotlin reflection doesn't work with Java getters the way you might expect, so we drop back to good ol' beans. - val properties: Map = Introspector.getBeanInfo(clazz).propertyDescriptors.filter { it.name != "class" }.groupBy { it.name }.mapValues { it.value[0] } + val properties = Introspector.getBeanInfo(clazz).propertyDescriptors.filter { it.name != "class" }.groupBy { it.name }.mapValues { it.value[0] } val rc: MutableList = ArrayList(kotlinConstructor.parameters.size) for (param in kotlinConstructor.parameters) { val name = param.name ?: throw NotSerializableException("Constructor parameter of $clazz has no name.") - val matchingProperty = properties[name] ?: throw NotSerializableException("No property matching constructor parameter named $name of $clazz. If using Java, check that you have the -parameters option specified in the Java compiler.") + val matchingProperty = properties[name] ?: throw NotSerializableException("No property matching constructor parameter named $name of $clazz." + + " If using Java, check that you have the -parameters option specified in the Java compiler.") // Check that the method has a getter in java. - val getter = matchingProperty.readMethod ?: throw NotSerializableException("Property has no getter method for $name of $clazz. If using Java and the parameter name looks anonymous, check that you have the -parameters option specified in the Java compiler.") - if (getter.genericReturnType == param.type.javaType) { - rc += PropertySerializer.make(name, getter) + val getter = matchingProperty.readMethod ?: throw NotSerializableException("Property has no getter method for $name of $clazz." + + " If using Java and the parameter name looks anonymous, check that you have the -parameters option specified in the Java compiler.") + if (constructorParamTakesReturnTypeOfGetter(getter, param)) { + rc += PropertySerializer.make(name, getter, factory) } else { throw NotSerializableException("Property type ${getter.genericReturnType} for $name of $clazz differs from constructor parameter type ${param.type.javaType}") } @@ -83,14 +87,16 @@ private fun propertiesForSerialization(kotlinConstructor: KFunction return rc } -private fun propertiesForSerialization(clazz: Class<*>): Collection { +private fun constructorParamTakesReturnTypeOfGetter(getter: Method, param: KParameter): Boolean = TypeToken.of(param.type.javaType).isSupertypeOf(getter.genericReturnType) + +private fun propertiesForSerialization(clazz: Class<*>, factory: SerializerFactory): Collection { // Kotlin reflection doesn't work with Java getters the way you might expect, so we drop back to good ol' beans. val properties = Introspector.getBeanInfo(clazz).propertyDescriptors.filter { it.name != "class" }.sortedBy { it.name } val rc: MutableList = ArrayList(properties.size) for (property in properties) { // Check that the method has a getter in java. val getter = property.readMethod ?: throw NotSerializableException("Property has no getter method for ${property.name} of $clazz.") - rc += PropertySerializer.make(property.name, getter) + rc += PropertySerializer.make(property.name, getter, factory) } return rc } @@ -104,6 +110,7 @@ internal fun interfacesForSerialization(clazz: Class<*>): List { private fun exploreType(type: Type?, interfaces: MutableSet) { val clazz = (type as? Class<*>) ?: (type as? ParameterizedType)?.rawType as? Class<*> if (clazz != null) { + if (clazz.isInterface) interfaces += clazz for (newInterface in clazz.genericInterfaces) { if (newInterface !in interfaces) { interfaces += newInterface diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt index f440d62c2a..3cbfad41ba 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt @@ -14,10 +14,10 @@ import kotlin.collections.LinkedHashSet * @param serializerFactory This is the factory for [AMQPSerializer] instances and can be shared across multiple * instances and threads. */ -class SerializationOutput(private val serializerFactory: SerializerFactory = SerializerFactory()) { +open class SerializationOutput(internal val serializerFactory: SerializerFactory = SerializerFactory()) { // TODO: we're not supporting object refs yet private val objectHistory: MutableMap = IdentityHashMap() - private val serializerHistory: MutableSet = LinkedHashSet() + private val serializerHistory: MutableSet> = LinkedHashSet() private val schemaHistory: MutableSet = LinkedHashSet() /** @@ -64,19 +64,21 @@ class SerializationOutput(private val serializerFactory: SerializerFactory = Ser internal fun writeObject(obj: Any, data: Data, type: Type) { val serializer = serializerFactory.get(obj.javaClass, type) if (serializer !in serializerHistory) { + serializerHistory.add(serializer) serializer.writeClassInfo(this) } serializer.writeObject(obj, data, type, this) } - internal fun writeTypeNotations(vararg typeNotation: TypeNotation): Boolean { + open internal fun writeTypeNotations(vararg typeNotation: TypeNotation): Boolean { return schemaHistory.addAll(typeNotation) } - internal fun requireSerializer(type: Type) { - if (type != SerializerFactory.AnyType) { + open internal fun requireSerializer(type: Type) { + if (type != SerializerFactory.AnyType && type != Object::class.java) { val serializer = serializerFactory.get(null, type) if (serializer !in serializerHistory) { + serializerHistory.add(serializer) serializer.writeClassInfo(this) } } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt index 1456c9a7ca..207f0979df 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt @@ -10,18 +10,19 @@ import java.io.NotSerializableException import java.lang.reflect.GenericArrayType import java.lang.reflect.ParameterizedType import java.lang.reflect.Type +import java.lang.reflect.WildcardType import java.util.* import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CopyOnWriteArrayList import javax.annotation.concurrent.ThreadSafe /** * Factory of serializers designed to be shared across threads and invocations. */ +// TODO: enums // TODO: object references // TODO: class references? (e.g. cheat with repeated descriptors using a long encoding, like object ref proposal) // TODO: Inner classes etc -// TODO: support for custom serialisation of core types (of e.g. PublicKey, Throwables) -// TODO: exclude schemas for core types that don't need custom serializers that everyone already knows the schema for. // TODO: support for intern-ing of deserialized objects for some core types (e.g. PublicKey) for memory efficiency // TODO: maybe support for caching of serialized form of some core types for performance // TODO: profile for performance in general @@ -30,10 +31,13 @@ import javax.annotation.concurrent.ThreadSafe // TODO: incorporate the class carpenter for classes not on the classpath. // TODO: apply class loader logic and an "app context" throughout this code. // TODO: schema evolution solution when the fingerprints do not line up. +// TODO: allow definition of well known types that are left out of the schema. +// TODO: automatically support byte[] without having to wrap in [Binary]. @ThreadSafe class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { - private val serializersByType = ConcurrentHashMap() - private val serializersByDescriptor = ConcurrentHashMap() + private val serializersByType = ConcurrentHashMap>() + private val serializersByDescriptor = ConcurrentHashMap>() + private val customSerializers = CopyOnWriteArrayList>() /** * Look up, and manufacture if necessary, a serializer for the given type. @@ -42,7 +46,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { * restricted type processing). */ @Throws(NotSerializableException::class) - fun get(actualType: Class<*>?, declaredType: Type): AMQPSerializer { + fun get(actualType: Class<*>?, declaredType: Type): AMQPSerializer { if (declaredType is ParameterizedType) { return serializersByType.computeIfAbsent(declaredType) { // We allow only Collection and Map. @@ -50,7 +54,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { if (rawType is Class<*>) { checkParameterisedTypesConcrete(declaredType.actualTypeArguments) if (Collection::class.java.isAssignableFrom(rawType)) { - CollectionSerializer(declaredType) + CollectionSerializer(declaredType, this) } else if (Map::class.java.isAssignableFrom(rawType)) { makeMapSerializer(declaredType) } else { @@ -63,27 +67,44 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } else if (declaredType is Class<*>) { // Simple classes allowed if (Collection::class.java.isAssignableFrom(declaredType)) { - return serializersByType.computeIfAbsent(declaredType) { CollectionSerializer(DeserializedParameterizedType(declaredType, arrayOf(AnyType), null)) } + return serializersByType.computeIfAbsent(declaredType) { CollectionSerializer(DeserializedParameterizedType(declaredType, arrayOf(AnyType), null), this) } } else if (Map::class.java.isAssignableFrom(declaredType)) { return serializersByType.computeIfAbsent(declaredType) { makeMapSerializer(DeserializedParameterizedType(declaredType, arrayOf(AnyType, AnyType), null)) } } else { return makeClassSerializer(actualType ?: declaredType) } } else if (declaredType is GenericArrayType) { - return serializersByType.computeIfAbsent(declaredType) { ArraySerializer(declaredType) } + return serializersByType.computeIfAbsent(declaredType) { ArraySerializer(declaredType, this) } } else { throw NotSerializableException("Declared types of $declaredType are not supported.") } } + /** + * Lookup and manufacture a serializer for the given AMQP type descriptor, assuming we also have the necessary types + * contained in the [Schema]. + */ @Throws(NotSerializableException::class) - fun get(typeDescriptor: Any, envelope: Envelope): AMQPSerializer { + fun get(typeDescriptor: Any, schema: Schema): AMQPSerializer { return serializersByDescriptor[typeDescriptor] ?: { - processSchema(envelope.schema) + processSchema(schema) serializersByDescriptor[typeDescriptor] ?: throw NotSerializableException("Could not find type matching descriptor $typeDescriptor.") }() } + /** + * TODO: Add docs + */ + fun register(customSerializer: CustomSerializer) { + if (!serializersByDescriptor.containsKey(customSerializer.typeDescriptor)) { + customSerializers += customSerializer + serializersByDescriptor[customSerializer.typeDescriptor] = customSerializer + for (additional in customSerializer.additionalSerializers) { + register(additional) + } + } + } + private fun processSchema(schema: Schema) { for (typeNotation in schema.types) { processSchemaEntry(typeNotation) @@ -99,7 +120,14 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { private fun restrictedTypeForName(name: String): Type { return if (name.endsWith("[]")) { - DeserializedGenericArrayType(restrictedTypeForName(name.substring(0, name.lastIndex - 1))) + val elementType = restrictedTypeForName(name.substring(0, name.lastIndex - 1)) + if (elementType is ParameterizedType || elementType is GenericArrayType) { + DeserializedGenericArrayType(elementType) + } else if (elementType is Class<*>) { + java.lang.reflect.Array.newInstance(elementType, 0).javaClass + } else { + throw NotSerializableException("Not able to deserialize array type: $name") + } } else { DeserializedParameterizedType.make(name) } @@ -134,32 +162,52 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } } - private fun makeClassSerializer(clazz: Class<*>): AMQPSerializer { + private fun makeClassSerializer(clazz: Class<*>): AMQPSerializer { return serializersByType.computeIfAbsent(clazz) { - if (clazz.isArray) { - whitelisted(clazz.componentType) - ArraySerializer(clazz) - } else if (isPrimitive(clazz)) { + if (isPrimitive(clazz)) { AMQPPrimitiveSerializer(clazz) } else { - whitelisted(clazz) - ObjectSerializer(clazz) + findCustomSerializer(clazz) ?: { + if (clazz.isArray) { + whitelisted(clazz.componentType) + ArraySerializer(clazz, this) + } else { + whitelisted(clazz) + ObjectSerializer(clazz, this) + } + }() } } } + internal fun findCustomSerializer(clazz: Class<*>): AMQPSerializer? { + for (customSerializer in customSerializers) { + if (customSerializer.isSerializerFor(clazz)) { + return customSerializer + } + } + return null + } + private fun whitelisted(clazz: Class<*>): Boolean { - if (whitelist.hasListed(clazz) || clazz.isAnnotationPresent(CordaSerializable::class.java)) { + if (whitelist.hasListed(clazz) || hasAnnotationInHierarchy(clazz)) { return true } else { throw NotSerializableException("Class $clazz is not on the whitelist or annotated with @CordaSerializable.") } } - private fun makeMapSerializer(declaredType: ParameterizedType): AMQPSerializer { + // Recursively check the class, interfaces and superclasses for our annotation. + internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean { + return type.isAnnotationPresent(CordaSerializable::class.java) || + type.interfaces.any { it.isAnnotationPresent(CordaSerializable::class.java) || hasAnnotationInHierarchy(it) } + || (type.superclass != null && hasAnnotationInHierarchy(type.superclass)) + } + + private fun makeMapSerializer(declaredType: ParameterizedType): AMQPSerializer { val rawType = declaredType.rawType as Class<*> rawType.checkNotUnorderedHashMap() - return MapSerializer(declaredType) + return MapSerializer(declaredType, this) } companion object { @@ -185,12 +233,17 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { Char::class.java to "char", Date::class.java to "timestamp", UUID::class.java to "uuid", - ByteArray::class.java to "binary", + Binary::class.java to "binary", String::class.java to "string", Symbol::class.java to "symbol") } - object AnyType : Type { - override fun toString(): String = "*" + object AnyType : WildcardType { + override fun getUpperBounds(): Array = arrayOf(Object::class.java) + + override fun getLowerBounds(): Array = emptyArray() + + override fun toString(): String = "?" } } + diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/PublicKeySerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/PublicKeySerializer.kt new file mode 100644 index 0000000000..46536a1bed --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/PublicKeySerializer.kt @@ -0,0 +1,24 @@ +package net.corda.core.serialization.amqp.custom + +import net.corda.core.crypto.Crypto +import net.corda.core.serialization.amqp.* +import org.apache.qpid.proton.amqp.Binary +import org.apache.qpid.proton.codec.Data +import java.lang.reflect.Type +import java.security.PublicKey + +class PublicKeySerializer : CustomSerializer.Implements(PublicKey::class.java) { + override val additionalSerializers: Iterable> = emptyList() + + override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(Binary::class.java)!!, descriptor, emptyList()))) + + override fun writeDescribedObject(obj: PublicKey, data: Data, type: Type, output: SerializationOutput) { + // TODO: Instead of encoding to the default X509 format, we could have a custom per key type (space-efficient) serialiser. + output.writeObject(Binary(obj.encoded), data, clazz) + } + + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): PublicKey { + val A = input.readObject(obj, schema, ByteArray::class.java) as Binary + return Crypto.decodePublicKey(A.array) + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/ThrowableSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/ThrowableSerializer.kt new file mode 100644 index 0000000000..ed267ed44d --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/ThrowableSerializer.kt @@ -0,0 +1,81 @@ +package net.corda.core.serialization.amqp.custom + +import net.corda.core.serialization.amqp.CustomSerializer +import net.corda.core.serialization.amqp.SerializerFactory +import net.corda.core.serialization.amqp.constructorForDeserialization +import net.corda.core.serialization.amqp.propertiesForSerialization +import net.corda.core.utilities.CordaRuntimeException +import net.corda.core.utilities.CordaThrowable +import java.io.NotSerializableException + +class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy(Throwable::class.java, ThrowableProxy::class.java, factory) { + override val additionalSerializers: Iterable> = listOf(StackTraceElementSerializer(factory)) + + override fun toProxy(obj: Throwable): ThrowableProxy { + val extraProperties: MutableMap = LinkedHashMap() + val message = if (obj is CordaThrowable) { + // Try and find a constructor + try { + val constructor = constructorForDeserialization(obj.javaClass) + val props = propertiesForSerialization(constructor, obj.javaClass, factory) + for (prop in props) { + extraProperties[prop.name] = prop.readMethod.invoke(obj) + } + } catch(e: NotSerializableException) { + } + obj.originalMessage + } else { + obj.message + } + return ThrowableProxy(obj.javaClass.name, message, obj.stackTrace, obj.cause, obj.suppressed, extraProperties) + } + + override fun fromProxy(proxy: ThrowableProxy): Throwable { + try { + // TODO: This will need reworking when we have multiple class loaders + val clazz = Class.forName(proxy.exceptionClass, false, this.javaClass.classLoader) + // If it is CordaException or CordaRuntimeException, we can seek any constructor and then set the properties + // Otherwise we just make a CordaRuntimeException + if (CordaThrowable::class.java.isAssignableFrom(clazz) && Throwable::class.java.isAssignableFrom(clazz)) { + val constructor = constructorForDeserialization(clazz)!! + val throwable = constructor.callBy(constructor.parameters.map { it to proxy.additionalProperties[it.name] }.toMap()) + (throwable as CordaThrowable).apply { + if (this.javaClass.name != proxy.exceptionClass) this.originalExceptionClassName = proxy.exceptionClass + this.setMessage(proxy.message) + this.setCause(proxy.cause) + this.addSuppressed(proxy.suppressed) + } + return (throwable as Throwable).apply { + this.stackTrace = proxy.stackTrace + } + } + } catch (e: Exception) { + // If attempts to rebuild the exact exception fail, we fall through and build a runtime exception. + } + // If the criteria are not met or we experience an exception constructing the exception, we fall back to our own unchecked exception. + return CordaRuntimeException(proxy.exceptionClass).apply { + this.setMessage(proxy.message) + this.setCause(proxy.cause) + this.stackTrace = proxy.stackTrace + this.addSuppressed(proxy.suppressed) + } + } + + class ThrowableProxy( + val exceptionClass: String, + val message: String?, + val stackTrace: Array, + val cause: Throwable?, + val suppressed: Array, + val additionalProperties: Map) +} + +class StackTraceElementSerializer(factory: SerializerFactory) : CustomSerializer.Proxy(StackTraceElement::class.java, StackTraceElementProxy::class.java, factory) { + override val additionalSerializers: Iterable> = emptyList() + + override fun toProxy(obj: StackTraceElement): StackTraceElementProxy = StackTraceElementProxy(obj.className, obj.methodName, obj.fileName, obj.lineNumber) + + override fun fromProxy(proxy: StackTraceElementProxy): StackTraceElement = StackTraceElement(proxy.declaringClass, proxy.methodName, proxy.fileName, proxy.lineNumber) + + data class StackTraceElementProxy(val declaringClass: String, val methodName: String, val fileName: String?, val lineNumber: Int) +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/utilities/CordaException.kt b/core/src/main/kotlin/net/corda/core/utilities/CordaException.kt new file mode 100644 index 0000000000..907bbee408 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/utilities/CordaException.kt @@ -0,0 +1,103 @@ +package net.corda.core.utilities + +import net.corda.core.serialization.CordaSerializable +import java.util.* + +@CordaSerializable +interface CordaThrowable { + var originalExceptionClassName: String? + val originalMessage: String? + fun setMessage(message: String?) + fun setCause(cause: Throwable?) + fun addSuppressed(suppressed: Array) +} + +open class CordaException internal constructor(override var originalExceptionClassName: String? = null, + private var _message: String? = null, + private var _cause: Throwable? = null) : Exception(null, null, true, true), CordaThrowable { + + constructor(message: String?, + cause: Throwable?) : this(null, message, cause) + + override val message: String? + get() = if (originalExceptionClassName == null) originalMessage else { + if (originalMessage == null) "$originalExceptionClassName" else "$originalExceptionClassName: $originalMessage" + } + + override val cause: Throwable? + get() = _cause ?: super.cause + + override fun setMessage(message: String?) { + _message = message + } + + override fun setCause(cause: Throwable?) { + _cause = cause + } + + override fun addSuppressed(suppressed: Array) { + for (suppress in suppressed) { + addSuppressed(suppress) + } + } + + override val originalMessage: String? + get() = _message + + override fun hashCode(): Int { + return Arrays.deepHashCode(stackTrace) xor Objects.hash(originalExceptionClassName, originalMessage) + } + + override fun equals(other: Any?): Boolean { + return other is CordaException && + originalExceptionClassName == other.originalExceptionClassName && + message == other.message && + cause == other.cause && + Arrays.equals(stackTrace, other.stackTrace) && + Arrays.equals(suppressed, other.suppressed) + } +} + +open class CordaRuntimeException internal constructor(override var originalExceptionClassName: String?, + private var _message: String? = null, + private var _cause: Throwable? = null) : RuntimeException(null, null, true, true), CordaThrowable { + constructor(message: String?, cause: Throwable?) : this(null, message, cause) + + override val message: String? + get() = if (originalExceptionClassName == null) originalMessage else { + if (originalMessage == null) "$originalExceptionClassName" else "$originalExceptionClassName: $originalMessage" + } + + override val cause: Throwable? + get() = _cause ?: super.cause + + override fun setMessage(message: String?) { + _message = message + } + + override fun setCause(cause: Throwable?) { + _cause = cause + } + + override fun addSuppressed(suppressed: Array) { + for (suppress in suppressed) { + addSuppressed(suppress) + } + } + + override val originalMessage: String? + get() = _message + + override fun hashCode(): Int { + return Arrays.deepHashCode(stackTrace) xor Objects.hash(originalExceptionClassName, originalMessage) + } + + override fun equals(other: Any?): Boolean { + return other is CordaRuntimeException && + originalExceptionClassName == other.originalExceptionClassName && + message == other.message && + cause == other.cause && + Arrays.equals(stackTrace, other.stackTrace) && + Arrays.equals(suppressed, other.suppressed) + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt b/core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt index d76708db57..5896a3c292 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt @@ -1,10 +1,14 @@ package net.corda.core.serialization.amqp +import net.corda.core.flows.FlowException import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.EmptyWhitelist +import net.corda.nodeapi.RPCException +import net.corda.testing.MEGA_CORP_PUBKEY import org.apache.qpid.proton.codec.DecoderImpl import org.apache.qpid.proton.codec.EncoderImpl import org.junit.Test +import java.io.IOException import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.* @@ -74,7 +78,14 @@ class SerializationOutputTests { override fun hashCode(): Int = ginger } - private fun serdes(obj: Any, factory: SerializerFactory = SerializerFactory()): Any { + @CordaSerializable + interface AnnotatedInterface + + data class InheritAnnotation(val foo: String) : AnnotatedInterface + + data class PolymorphicProperty(val foo: FooInterface?) + + private fun serdes(obj: Any, factory: SerializerFactory = SerializerFactory(), freshDeserializationFactory: SerializerFactory = SerializerFactory(), expectedEqual: Boolean = true): Any { val ser = SerializationOutput(factory) val bytes = ser.serialize(obj) @@ -93,15 +104,16 @@ class SerializationOutputTests { val result = decoder.readObject() as Envelope assertNotNull(result) - val des = DeserializationInput() + val des = DeserializationInput(freshDeserializationFactory) val desObj = des.deserialize(bytes) - assertTrue(Objects.deepEquals(obj, desObj)) + assertTrue(Objects.deepEquals(obj, desObj) == expectedEqual) // Now repeat with a re-used factory val ser2 = SerializationOutput(factory) val des2 = DeserializationInput(factory) val desObj2 = des2.deserialize(ser2.serialize(obj)) - assertTrue(Objects.deepEquals(obj, desObj2)) + assertTrue(Objects.deepEquals(obj, desObj2) == expectedEqual) + assertTrue(Objects.deepEquals(desObj, desObj2)) // TODO: add some schema assertions to check correctly formed. return desObj2 @@ -230,4 +242,109 @@ class SerializationOutputTests { val obj = MismatchType(456) serdes(obj) } + + @Test + fun `test custom serializers on public key`() { + val factory = SerializerFactory() + factory.register(net.corda.core.serialization.amqp.custom.PublicKeySerializer()) + val factory2 = SerializerFactory() + factory2.register(net.corda.core.serialization.amqp.custom.PublicKeySerializer()) + val obj = MEGA_CORP_PUBKEY + serdes(obj, factory, factory2) + } + + @Test + fun `test annotation is inherited`() { + val obj = InheritAnnotation("blah") + serdes(obj, SerializerFactory(EmptyWhitelist)) + } + + @Test + fun `test throwables serialize`() { + val factory = SerializerFactory() + factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + + val factory2 = SerializerFactory() + factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + + val obj = IllegalAccessException("message").fillInStackTrace() + serdes(obj, factory, factory2, false) + } + + @Test + fun `test complex throwables serialize`() { + val factory = SerializerFactory() + factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + + val factory2 = SerializerFactory() + factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + + try { + try { + throw IOException("Layer 1") + } catch(t: Throwable) { + throw IllegalStateException("Layer 2", t) + } + } catch(t: Throwable) { + serdes(t, factory, factory2, false) + } + } + + @Test + fun `test suppressed throwables serialize`() { + val factory = SerializerFactory() + factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + + val factory2 = SerializerFactory() + factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + + try { + try { + throw IOException("Layer 1") + } catch(t: Throwable) { + val e = IllegalStateException("Layer 2") + e.addSuppressed(t) + throw e + } + } catch(t: Throwable) { + serdes(t, factory, factory2, false) + } + } + + @Test + fun `test flow corda exception subclasses serialize`() { + val factory = SerializerFactory() + factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + + val factory2 = SerializerFactory() + factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + + val obj = FlowException("message").fillInStackTrace() + serdes(obj, factory, factory2) + } + + @Test + fun `test RPC corda exception subclasses serialize`() { + val factory = SerializerFactory() + factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + + val factory2 = SerializerFactory() + factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + + val obj = RPCException("message").fillInStackTrace() + serdes(obj, factory, factory2) + } + + @Test + fun `test polymorphic property`() { + val obj = PolymorphicProperty(FooImplements("Ginger", 12)) + serdes(obj) + } + + @Test + fun `test null polymorphic property`() { + val obj = PolymorphicProperty(null) + serdes(obj) + } + } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt index 16be6048b5..e5f279c692 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt @@ -9,6 +9,7 @@ import net.corda.core.requireExternal import net.corda.core.serialization.* import net.corda.core.toFuture import net.corda.core.toObservable +import net.corda.core.utilities.CordaRuntimeException import net.corda.nodeapi.config.OldConfig import rx.Observable import java.io.InputStream @@ -35,8 +36,7 @@ annotation class RPCSinceVersion(val version: Int) * Thrown to indicate a fatal error in the RPC system itself, as opposed to an error generated by the invoked * method. */ -@CordaSerializable -open class RPCException(msg: String, cause: Throwable?) : RuntimeException(msg, cause) { +open class RPCException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { constructor(msg: String) : this(msg, null) } diff --git a/node/build.gradle b/node/build.gradle index 5676649bce..1de49eee07 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -83,7 +83,10 @@ dependencies { // Artemis: for reliable p2p message queues. compile "org.apache.activemq:artemis-server:${artemis_version}" compile "org.apache.activemq:artemis-core-client:${artemis_version}" - runtime "org.apache.activemq:artemis-amqp-protocol:${artemis_version}" + runtime ("org.apache.activemq:artemis-amqp-protocol:${artemis_version}") { + // Gains our proton-j version from core module. + exclude group: 'org.apache.qpid', module: 'proton-j' + } // JAnsi: for drawing things to the terminal in nicely coloured ways. compile "org.fusesource.jansi:jansi:$jansi_version"