diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt index d5668cc494..d1ed63327f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt @@ -25,7 +25,7 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { private val objectHistory: MutableList = mutableListOf() internal companion object { - val BYTES_NEEDED_TO_PEEK: Int = 23 + private val BYTES_NEEDED_TO_PEEK: Int = 23 fun peekSize(bytes: ByteArray): Int { // There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor @@ -57,7 +57,6 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { inline internal fun deserializeAndReturnEnvelope(bytes: SerializedBytes): ObjectAndEnvelope = deserializeAndReturnEnvelope(bytes, T::class.java) - @Throws(NotSerializableException::class) private fun getEnvelope(bytes: ByteSequence): Envelope { // Check that the lead bytes match expected header @@ -94,20 +93,16 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { * be deserialized and a schema describing the types of the objects. */ @Throws(NotSerializableException::class) - fun deserialize(bytes: ByteSequence, clazz: Class): T { - return des { - val envelope = getEnvelope(bytes) - clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)) - } + fun deserialize(bytes: ByteSequence, clazz: Class): T = des { + val envelope = getEnvelope(bytes) + clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)) } @Throws(NotSerializableException::class) - internal fun deserializeAndReturnEnvelope(bytes: SerializedBytes, clazz: Class): ObjectAndEnvelope { - return des { - val envelope = getEnvelope(bytes) - // Now pick out the obj and schema from the envelope. - ObjectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope) - } + fun deserializeAndReturnEnvelope(bytes: SerializedBytes, clazz: Class): ObjectAndEnvelope = des { + val envelope = getEnvelope(bytes) + // Now pick out the obj and schema from the envelope. + ObjectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope) } internal fun readObjectOrNull(obj: Any?, schema: Schema, type: Type): Any? { @@ -115,36 +110,36 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { } internal fun readObject(obj: Any, schema: Schema, type: Type): Any = - if (obj is DescribedType && ReferencedObject.DESCRIPTOR == obj.descriptor) { - // It must be a reference to an instance that has already been read, cheaply and quickly returning it by reference. - val objectIndex = (obj.described as UnsignedInteger).toInt() - if (objectIndex !in 0..objectHistory.size) - throw NotSerializableException("Retrieval of existing reference failed. Requested index $objectIndex " + - "is outside of the bounds for the list of size: ${objectHistory.size}") + if (obj is DescribedType && ReferencedObject.DESCRIPTOR == obj.descriptor) { + // It must be a reference to an instance that has already been read, cheaply and quickly returning it by reference. + val objectIndex = (obj.described as UnsignedInteger).toInt() + if (objectIndex !in 0..objectHistory.size) + throw NotSerializableException("Retrieval of existing reference failed. Requested index $objectIndex " + + "is outside of the bounds for the list of size: ${objectHistory.size}") - val objectRetrieved = objectHistory[objectIndex] - if (!objectRetrieved::class.java.isSubClassOf(type.asClass()!!)) - throw NotSerializableException("Existing reference type mismatch. Expected: '$type', found: '${objectRetrieved::class.java}'") - objectRetrieved - } - else { - val objectRead = when (obj) { - is DescribedType -> { - // Look up serializer in factory by descriptor - val serializer = serializerFactory.get(obj.descriptor, schema) - if (serializer.type != type && with(serializer.type) { !isSubClassOf(type) && !materiallyEquivalentTo(type) }) - throw NotSerializableException("Described type with descriptor ${obj.descriptor} was " + - "expected to be of type $type but was ${serializer.type}") - serializer.readObject(obj.described, schema, this) + val objectRetrieved = objectHistory[objectIndex] + if (!objectRetrieved::class.java.isSubClassOf(type.asClass()!!)) + throw NotSerializableException("Existing reference type mismatch. Expected: '$type', found: '${objectRetrieved::class.java}'") + objectRetrieved + } else { + val objectRead = when (obj) { + is DescribedType -> { + // Look up serializer in factory by descriptor + val serializer = serializerFactory.get(obj.descriptor, schema) + if (serializer.type != type && with(serializer.type) { !isSubClassOf(type) && !materiallyEquivalentTo(type) }) + throw NotSerializableException("Described type with descriptor ${obj.descriptor} was " + + "expected to be of type $type but was ${serializer.type}") + serializer.readObject(obj.described, schema, this) + } + is Binary -> obj.array + else -> obj // this will be the case for primitive types like [boolean] et al. } - is Binary -> obj.array - else -> obj // this will be the case for primitive types like [boolean] et al. + + // Store the reference in case we need it later on. + // Skip for primitive types as they are too small and overhead of referencing them will be much higher than their content + if (type.asClass()?.isPrimitive != true) objectHistory.add(objectRead) + objectRead } - // Store the reference in case we need it later on. - // Skip for primitive types as they are too small and overhead of referencing them will be much higher than their content - if (suitableForObjectReference(objectRead.javaClass)) objectHistory.add(objectRead) - objectRead - } /** * TODO: Currently performs rather basic checks aimed in particular at [java.util.List>] and @@ -152,5 +147,5 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) { * In the future tighter control might be needed */ private fun Type.materiallyEquivalentTo(that: Type): Boolean = - asClass() == that.asClass() && that is ParameterizedType -} \ No newline at end of file + asClass() == that.asClass() && that is ParameterizedType +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt index 3209866821..96b2d729d4 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt @@ -53,6 +53,7 @@ class DeserializedParameterizedType(private val rawType: Class<*>, private val p var typeStart = 0 var needAType = true var skippingWhitespace = false + while (pos < params.length) { if (params[pos] == '<') { val typeEnd = pos++ @@ -102,7 +103,7 @@ class DeserializedParameterizedType(private val rawType: Class<*>, private val p } else if (!skippingWhitespace && (params[pos] == '.' || params[pos].isJavaIdentifierPart())) { pos++ } else { - throw NotSerializableException("Invalid character in middle of type: ${params[pos]}") + throw NotSerializableException("Invalid character ${params[pos]} in middle of type $params at idx $pos") } } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumSerializer.kt new file mode 100644 index 0000000000..294a0c3cf0 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumSerializer.kt @@ -0,0 +1,51 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.apache.qpid.proton.codec.Data +import java.lang.reflect.Type +import java.io.NotSerializableException + +/** + * Our definition of an enum with the AMQP spec is a list (of two items, a string and an int) that is + * a restricted type with a number of choices associated with it + */ +class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: SerializerFactory) : AMQPSerializer { + override val type: Type = declaredType + override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" + private val typeNotation: TypeNotation + + init { + typeNotation = RestrictedType( + SerializerFactory.nameForType(declaredType), + null, emptyList(), "list", Descriptor(typeDescriptor, null), + declaredClass.enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map { + Choice(it.first.toString(), it.second.toString()) + }) + } + + override fun writeClassInfo(output: SerializationOutput) { + output.writeTypeNotations(typeNotation) + } + + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { + val enumName = (obj as List<*>)[0] as String + val enumOrd = obj[1] as Int + val fromOrd = type.asClass()!!.enumConstants[enumOrd] + + if (enumName != fromOrd?.toString()) { + throw NotSerializableException("Deserializing obj as enum $type with value $enumName.$enumOrd but " + + "ordinality has changed") + } + return fromOrd + } + + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + if (obj !is Enum<*>) throw NotSerializableException("Serializing $obj as enum when it isn't") + + data.withDescribed(typeNotation.descriptor) { + withList { + data.putString(obj.name) + data.putInt(obj.ordinal) + } + } + } +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt index 8b74459fe5..314c9b6aa9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt @@ -13,7 +13,7 @@ import kotlin.collections.map /** * Serialization / deserialization of certain supported [Map] types. */ -class MapSerializer(val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer { +class MapSerializer(private val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer { override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(declaredType.toString()) override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt index 6a506685c6..2dfe80015a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt @@ -191,12 +191,10 @@ sealed class TypeNotation : DescribedType { companion object { fun get(obj: Any): TypeNotation { val describedType = obj as DescribedType - if (describedType.descriptor == CompositeType.DESCRIPTOR) { - return CompositeType.get(describedType) - } else if (describedType.descriptor == RestrictedType.DESCRIPTOR) { - return RestrictedType.get(describedType) - } else { - throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.") + return when (describedType.descriptor) { + CompositeType.DESCRIPTOR -> CompositeType.get(describedType) + RestrictedType.DESCRIPTOR -> RestrictedType.get(describedType) + else -> throw NotSerializableException("Unexpected descriptor ${describedType.descriptor}.") } } } @@ -252,7 +250,12 @@ data class CompositeType(override val name: String, override val label: String?, } } -data class RestrictedType(override val name: String, override val label: String?, override val provides: List, val source: String, override val descriptor: Descriptor, val choices: List) : TypeNotation() { +data class RestrictedType(override val name: String, + override val label: String?, + override val provides: List, + val source: String, + override val descriptor: Descriptor, + val choices: List) : TypeNotation() { companion object : DescribedTypeConstructor { val DESCRIPTOR = DescriptorRegistry.RESTRICTED_TYPE.amqpDescriptor @@ -290,6 +293,9 @@ data class RestrictedType(override val name: String, override val label: String? } sb.append(">\n") sb.append(" $descriptor\n") + choices.forEach { + sb.append(" $it\n") + } sb.append("") return sb.toString() } @@ -352,6 +358,7 @@ data class ReferencedObject(private val refCounter: Int) : DescribedType { } private val ARRAY_HASH: String = "Array = true" +private val ENUM_HASH: String = "Enum = true" private val ALREADY_SEEN_HASH: String = "Already seen = true" private val NULLABLE_HASH: String = "Nullable = true" private val NOT_NULLABLE_HASH: String = "Nullable = false" @@ -382,7 +389,7 @@ internal fun fingerprintForDescriptors(vararg typeDescriptors: String): String { return hasher.hash().asBytes().toBase64() } -private fun Hasher.fingerprintWithCustomSerializerOrElse(factory: SerializerFactory, clazz: Class<*>, declaredType: Type, block: () -> Hasher) : Hasher { +private fun Hasher.fingerprintWithCustomSerializerOrElse(factory: SerializerFactory, clazz: Class<*>, declaredType: Type, block: () -> Hasher): Hasher { // Need to check if a custom serializer is applicable val customSerializer = factory.findCustomSerializer(clazz, declaredType) return if (customSerializer != null) { @@ -400,51 +407,58 @@ private fun fingerprintForType(type: Type, contextType: Type?, alreadySeen: Muta } else { alreadySeen += type try { - if (type is SerializerFactory.AnyType) { - hasher.putUnencodedChars(ANY_TYPE_HASH) - } else if (type is Class<*>) { - if (type.isArray) { - fingerprintForType(type.componentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH) - } else if (SerializerFactory.isPrimitive(type)) { - hasher.putUnencodedChars(type.name) - } else if (isCollectionOrMap(type)) { - hasher.putUnencodedChars(type.name) - } else { - hasher.fingerprintWithCustomSerializerOrElse(factory, type, type) { - if (type.kotlin.objectInstance != null) { - // TODO: name collision is too likely for kotlin objects, we need to introduce some reference - // to the CorDapp but maybe reference to the JAR in the short term. - hasher.putUnencodedChars(type.name) - } else { - fingerprintForObject(type, type, alreadySeen, hasher, factory) + when (type) { + is SerializerFactory.AnyType -> hasher.putUnencodedChars(ANY_TYPE_HASH) + is Class<*> -> { + if (type.isArray) { + fingerprintForType(type.componentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH) + } else if (SerializerFactory.isPrimitive(type)) { + hasher.putUnencodedChars(type.name) + } else if (isCollectionOrMap(type)) { + hasher.putUnencodedChars(type.name) + } else if (type.isEnum) { + // ensures any change to the enum (adding constants) will trigger the need for evolution + hasher.apply { + type.enumConstants.forEach { + putUnencodedChars(it.toString()) + } + }.putUnencodedChars(type.name).putUnencodedChars(ENUM_HASH) + } else { + hasher.fingerprintWithCustomSerializerOrElse(factory, type, type) { + if (type.kotlin.objectInstance != null) { + // TODO: name collision is too likely for kotlin objects, we need to introduce some reference + // to the CorDapp but maybe reference to the JAR in the short term. + hasher.putUnencodedChars(type.name) + } else { + fingerprintForObject(type, type, alreadySeen, hasher, factory) + } } } } - } else if (type is ParameterizedType) { - // Hash the rawType + params - val clazz = type.rawType as Class<*> - val startingHash = if (isCollectionOrMap(clazz)) { - hasher.putUnencodedChars(clazz.name) - } else { - hasher.fingerprintWithCustomSerializerOrElse(factory, clazz, type) { - fingerprintForObject(type, type, alreadySeen, hasher, factory) + is ParameterizedType -> { + // Hash the rawType + params + val clazz = type.rawType as Class<*> + val startingHash = if (isCollectionOrMap(clazz)) { + hasher.putUnencodedChars(clazz.name) + } else { + hasher.fingerprintWithCustomSerializerOrElse(factory, clazz, type) { + fingerprintForObject(type, type, alreadySeen, hasher, factory) + } + } + // ... and concatentate the type data for each parameter type. + type.actualTypeArguments.fold(startingHash) { orig, paramType -> + fingerprintForType(paramType, type, alreadySeen, orig, factory) } } - // ... and concatentate the type data for each parameter type. - type.actualTypeArguments.fold(startingHash) { orig, paramType -> fingerprintForType(paramType, type, alreadySeen, orig, factory) } - } else if (type is GenericArrayType) { - // Hash the element type + some array hash - fingerprintForType(type.genericComponentType, contextType, alreadySeen, hasher, factory).putUnencodedChars(ARRAY_HASH) - } else if (type is TypeVariable<*>) { - // TODO: include bounds - hasher.putUnencodedChars(type.name).putUnencodedChars(TYPE_VARIABLE_HASH) - } else if (type is WildcardType) { - hasher.putUnencodedChars(type.typeName).putUnencodedChars(WILDCARD_TYPE_HASH) + // Hash the element type + some array hash + is GenericArrayType -> fingerprintForType(type.genericComponentType, contextType, alreadySeen, + hasher, factory).putUnencodedChars(ARRAY_HASH) + // TODO: include bounds + is TypeVariable<*> -> hasher.putUnencodedChars(type.name).putUnencodedChars(TYPE_VARIABLE_HASH) + is WildcardType -> hasher.putUnencodedChars(type.typeName).putUnencodedChars(WILDCARD_TYPE_HASH) + else -> throw NotSerializableException("Don't know how to hash") } - else { - throw NotSerializableException("Don't know how to hash") - } - } catch(e: NotSerializableException) { + } catch (e: NotSerializableException) { val msg = "${e.message} -> $type" logger.error(msg, e) throw NotSerializableException(msg) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index c0b8b19c3f..2d240c5329 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -16,12 +16,11 @@ import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArrayList import javax.annotation.concurrent.ThreadSafe -data class schemaAndDescriptor (val schema: Schema, val typeDescriptor: Any) +data class schemaAndDescriptor(val schema: Schema, val typeDescriptor: Any) /** * Factory of serializers designed to be shared across threads and invocations. */ -// TODO: enums // TODO: object references - need better fingerprinting? // TODO: class references? (e.g. cheat with repeated descriptors using a long encoding, like object ref proposal) // TODO: Inner classes etc. Should we allow? Currently not considered. @@ -39,17 +38,20 @@ data class schemaAndDescriptor (val schema: Schema, val typeDescriptor: Any) // TODO: need to rethink matching of constructor to properties in relation to implementing interfaces and needing those properties etc. // TODO: need to support super classes as well as interfaces with our current code base... what's involved? If we continue to ban, what is the impact? @ThreadSafe -class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { +class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { private val serializersByType = ConcurrentHashMap>() private val serializersByDescriptor = ConcurrentHashMap>() private val customSerializers = CopyOnWriteArrayList>() private val classCarpenter = ClassCarpenter(cl) - val classloader : ClassLoader + val classloader: ClassLoader get() = classCarpenter.classloader - fun getEvolutionSerializer(typeNotation: TypeNotation, newSerializer: ObjectSerializer) : AMQPSerializer { + private fun getEvolutionSerializer(typeNotation: TypeNotation, newSerializer: AMQPSerializer): AMQPSerializer { return serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { - EvolutionSerializer.make(typeNotation as CompositeType, newSerializer, this) + when (typeNotation) { + is CompositeType -> EvolutionSerializer.make(typeNotation, newSerializer as ObjectSerializer, this) + is RestrictedType -> throw NotSerializableException("Enum evolution is not currently supported") + } } } @@ -66,18 +68,21 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType - val serializer = if (Collection::class.java.isAssignableFrom(declaredClass)) { - serializersByType.computeIfAbsent(declaredType) { - CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( - declaredClass, arrayOf(AnyType), null), this) + val serializer = when { + (Collection::class.java.isAssignableFrom(declaredClass)) -> { + serializersByType.computeIfAbsent(declaredType) { + CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( + declaredClass, arrayOf(AnyType), null), this) + } } - } else if (Map::class.java.isAssignableFrom(declaredClass)) { - serializersByType.computeIfAbsent(declaredClass) { + Map::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) { makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( declaredClass, arrayOf(AnyType, AnyType), null)) } - } else { - makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) + Enum::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) { + EnumSerializer(actualType, actualClass ?: declaredClass, this) + } + else -> makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) } serializersByDescriptor.putIfAbsent(serializer.typeDescriptor, serializer) @@ -90,17 +95,17 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { * type. */ // TODO: test GenericArrayType - private fun inferTypeVariables(actualClass: Class<*>?, declaredClass: Class<*>, declaredType: Type): Type? { - if (declaredType is ParameterizedType) { - return inferTypeVariables(actualClass, declaredClass, declaredType) - } else if (declaredType is Class<*>) { - // Nothing to infer, otherwise we'd have ParameterizedType - return actualClass - } else if (declaredType is GenericArrayType) { - val declaredComponent = declaredType.genericComponentType - return inferTypeVariables(actualClass?.componentType, declaredComponent.asClass()!!, declaredComponent)?.asArray() - } else return null - } + private fun inferTypeVariables(actualClass: Class<*>?, declaredClass: Class<*>, declaredType: Type): Type? = + when (declaredType) { + is ParameterizedType -> inferTypeVariables(actualClass, declaredClass, declaredType) + // Nothing to infer, otherwise we'd have ParameterizedType + is Class<*> -> actualClass + is GenericArrayType -> { + val declaredComponent = declaredType.genericComponentType + inferTypeVariables(actualClass?.componentType, declaredComponent.asClass()!!, declaredComponent)?.asArray() + } + else -> null + } /** * Try and infer concrete types for any generics type variables for the actual class encountered, based on the declared @@ -117,8 +122,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { if (implementationChain != null) { val start = implementationChain.last() val rest = implementationChain.dropLast(1).drop(1) - val resolver = rest.reversed().fold(TypeResolver().where(start, declaredType)) { - resolved, chainEntry -> + val resolver = rest.reversed().fold(TypeResolver().where(start, declaredType)) { resolved, chainEntry -> val newResolved = resolved.resolveType(chainEntry) TypeResolver().where(chainEntry, newResolved) } @@ -194,7 +198,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { // doesn't match that of the serialised object then we are dealing with different // instance of the class, as such we need to build an EvolutionSerialiser if (serialiser.typeDescriptor != typeNotation.descriptor.name) { - getEvolutionSerializer(typeNotation, serialiser as ObjectSerializer) + getEvolutionSerializer(typeNotation, serialiser) } } catch (e: ClassNotFoundException) { if (sentinel || (typeNotation !is CompositeType)) throw e @@ -210,16 +214,14 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { } private fun processSchemaEntry(typeNotation: TypeNotation) = when (typeNotation) { - is CompositeType -> processCompositeType(typeNotation) // java.lang.Class (whether a class or interface) - is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics - } - - private fun processRestrictedType(typeNotation: RestrictedType): AMQPSerializer { - // TODO: class loader logic, and compare the schema. - val type = typeForName(typeNotation.name, classloader) - return get(null, type) + is CompositeType -> processCompositeType(typeNotation) // java.lang.Class (whether a class or interface) + is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics } + // TODO: class loader logic, and compare the schema. + private fun processRestrictedType(typeNotation: RestrictedType) = get(null, + typeForName(typeNotation.name, classloader)) + private fun processCompositeType(typeNotation: CompositeType): AMQPSerializer { // TODO: class loader logic, and compare the schema. val type = typeForName(typeNotation.name, classloader) @@ -233,7 +235,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { findCustomSerializer(clazz, declaredType) ?: run { if (type.isArray()) { // Allow Object[] since this can be quite common (i.e. an untyped array) - if(type.componentType() != Object::class.java) whitelisted(type.componentType()) + if (type.componentType() != Object::class.java) whitelisted(type.componentType()) if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) else ArraySerializer.make(type, this) } else if (clazz.kotlin.objectInstance != null) { @@ -248,8 +250,9 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { } internal fun findCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer? { - // e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is AbstractMap, only Map. - // Otherwise it needs to inject additional schema for a RestrictedType source of the super type. Could be done, but do we need it? + // e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is + // AbstractMap, only Map. Otherwise it needs to inject additional schema for a RestrictedType source of the + // super type. Could be done, but do we need it? for (customSerializer in customSerializers) { if (customSerializer.isSerializerFor(clazz)) { val declaredSuperClass = declaredType.asClass()?.superclass @@ -258,7 +261,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { } else { // Make a subclass serializer for the subclass and return that... @Suppress("UNCHECKED_CAST") - return CustomSerializer.SubClass(clazz, customSerializer as CustomSerializer) + return CustomSerializer.SubClass(clazz, customSerializer as CustomSerializer) } } } @@ -277,7 +280,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl : ClassLoader) { (!whitelist.hasListed(clazz) && !hasAnnotationInHierarchy(clazz)) // Recursively check the class, interfaces and superclasses for our annotation. - internal fun hasAnnotationInHierarchy(type: Class<*>): Boolean { + private fun hasAnnotationInHierarchy(type: Class<*>): Boolean { return type.isAnnotationPresent(CordaSerializable::class.java) || type.interfaces.any { hasAnnotationInHierarchy(it) } || (type.superclass != null && hasAnnotationInHierarchy(type.superclass)) diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerialiseEnumTests.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerialiseEnumTests.java new file mode 100644 index 0000000000..e0b65ad27c --- /dev/null +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerialiseEnumTests.java @@ -0,0 +1,36 @@ +package net.corda.nodeapi.internal.serialization.amqp; + +import org.junit.Test; + +import net.corda.nodeapi.internal.serialization.AllWhitelist; +import net.corda.core.serialization.SerializedBytes; + +import java.io.NotSerializableException; + +public class JavaSerialiseEnumTests { + + public enum Bras { + TSHIRT, UNDERWIRE, PUSHUP, BRALETTE, STRAPLESS, SPORTS, BACKLESS, PADDED + } + + private static class Bra { + private final Bras bra; + + private Bra(Bras bra) { + this.bra = bra; + } + + public Bras getBra() { + return this.bra; + } + } + + @Test + public void testJavaConstructorAnnotations() throws NotSerializableException { + Bra bra = new Bra(Bras.UNDERWIRE); + + SerializerFactory factory1 = new SerializerFactory(AllWhitelist.INSTANCE, ClassLoader.getSystemClassLoader()); + SerializationOutput ser = new SerializationOutput(factory1); + SerializedBytes bytes = ser.serialize(bra); + } +} diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java index 171b9cbe72..34ca1abaa6 100644 --- a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java @@ -25,15 +25,17 @@ public class JavaSerializationOutputTests { } @ConstructorForDeserialization - public Foo(String fred, int count) { + private Foo(String fred, int count) { this.bob = fred; this.count = count; } + @SuppressWarnings("unused") public String getFred() { return bob; } + @SuppressWarnings("unused") public int getCount() { return count; } @@ -61,15 +63,17 @@ public class JavaSerializationOutputTests { private final String bob; private final int count; - public UnAnnotatedFoo(String fred, int count) { + private UnAnnotatedFoo(String fred, int count) { this.bob = fred; this.count = count; } + @SuppressWarnings("unused") public String getFred() { return bob; } + @SuppressWarnings("unused") public int getCount() { return count; } @@ -97,7 +101,7 @@ public class JavaSerializationOutputTests { private final String fred; private final Integer count; - public BoxedFoo(String fred, Integer count) { + private BoxedFoo(String fred, Integer count) { this.fred = fred; this.count = count; } @@ -134,7 +138,7 @@ public class JavaSerializationOutputTests { private final String fred; private final Integer count; - public BoxedFooNotNull(String fred, Integer count) { + private BoxedFooNotNull(String fred, Integer count) { this.fred = fred; this.count = count; } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt index 0494497138..4f9b7b6872 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt @@ -6,7 +6,6 @@ import net.corda.nodeapi.internal.serialization.AllWhitelist import net.corda.nodeapi.internal.serialization.EmptyWhitelist import java.io.NotSerializableException - fun testDefaultFactory() = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) fun testDefaultFactoryWithWhitelist() = SerializerFactory(EmptyWhitelist, ClassLoader.getSystemClassLoader()) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumTests.kt new file mode 100644 index 0000000000..d46ec93c5a --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/EnumTests.kt @@ -0,0 +1,189 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.junit.Test +import java.time.DayOfWeek + +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +import java.io.File +import java.io.NotSerializableException + +import net.corda.core.serialization.SerializedBytes + +class EnumTests { + enum class Bras { + TSHIRT, UNDERWIRE, PUSHUP, BRALETTE, STRAPLESS, SPORTS, BACKLESS, PADDED + } + + // The state of the OldBras enum when the tests in changedEnum1 were serialised + // - use if the test file needs regenerating + //enum class OldBras { + // TSHIRT, UNDERWIRE, PUSHUP, BRALETTE + //} + + // the new state, SPACER has been added to change the ordinality + enum class OldBras { + SPACER, TSHIRT, UNDERWIRE, PUSHUP, BRALETTE + } + + // The state of the OldBras2 enum when the tests in changedEnum2 were serialised + // - use if the test file needs regenerating + //enum class OldBras2 { + // TSHIRT, UNDERWIRE, PUSHUP, BRALETTE + //} + + // the new state, note in the test we serialised with value UNDERWIRE so the spacer + // occuring after this won't have changed the ordinality of our serialised value + // and thus should still be deserialisable + enum class OldBras2 { + TSHIRT, UNDERWIRE, PUSHUP, SPACER, BRALETTE, SPACER2 + } + + + enum class BrasWithInit (val someList: List) { + TSHIRT(emptyList()), + UNDERWIRE(listOf(1, 2, 3)), + PUSHUP(listOf(100, 200)), + BRALETTE(emptyList()) + } + + private val brasTestName = "${this.javaClass.name}\$Bras" + + companion object { + /** + * If you want to see the schema encoded into the envelope after serialisation change this to true + */ + private const val VERBOSE = false + } + + @Suppress("NOTHING_TO_INLINE") + inline private fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz" + + private val sf1 = testDefaultFactory() + + @Test + fun serialiseSimpleTest() { + data class C(val c: Bras) + + val schema = TestSerializationOutput(VERBOSE, sf1).serializeAndReturnSchema(C(Bras.UNDERWIRE)).schema + + assertEquals(2, schema.types.size) + val schema_c = schema.types.find { it.name == classTestName("C") } as CompositeType + val schema_bras = schema.types.find { it.name == brasTestName } as RestrictedType + + assertNotNull(schema_c) + assertNotNull(schema_bras) + + assertEquals(1, schema_c.fields.size) + assertEquals("c", schema_c.fields.first().name) + assertEquals(brasTestName, schema_c.fields.first().type) + + assertEquals(8, schema_bras.choices.size) + Bras.values().forEach { + val bra = it + assertNotNull (schema_bras.choices.find { it.name == bra.name }) + } + } + + @Test + fun deserialiseSimpleTest() { + data class C(val c: Bras) + + val objAndEnvelope = DeserializationInput(sf1).deserializeAndReturnEnvelope( + TestSerializationOutput(VERBOSE, sf1).serialize(C(Bras.UNDERWIRE))) + + val obj = objAndEnvelope.obj + val schema = objAndEnvelope.envelope.schema + + assertEquals(2, schema.types.size) + val schema_c = schema.types.find { it.name == classTestName("C") } as CompositeType + val schema_bras = schema.types.find { it.name == brasTestName } as RestrictedType + + assertEquals(1, schema_c.fields.size) + assertEquals("c", schema_c.fields.first().name) + assertEquals(brasTestName, schema_c.fields.first().type) + + assertEquals(8, schema_bras.choices.size) + Bras.values().forEach { + val bra = it + assertNotNull (schema_bras.choices.find { it.name == bra.name }) + } + + // Test the actual deserialised object + assertEquals(obj.c, Bras.UNDERWIRE) + } + + @Test + fun multiEnum() { + data class Support (val top: Bras, val day : DayOfWeek) + data class WeeklySupport (val tops: List) + + val week = WeeklySupport (listOf( + Support (Bras.PUSHUP, DayOfWeek.MONDAY), + Support (Bras.UNDERWIRE, DayOfWeek.WEDNESDAY), + Support (Bras.PADDED, DayOfWeek.SUNDAY))) + + val obj = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(week)) + + assertEquals(week.tops[0].top, obj.tops[0].top) + assertEquals(week.tops[0].day, obj.tops[0].day) + assertEquals(week.tops[1].top, obj.tops[1].top) + assertEquals(week.tops[1].day, obj.tops[1].day) + assertEquals(week.tops[2].top, obj.tops[2].top) + assertEquals(week.tops[2].day, obj.tops[2].day) + } + + @Test + fun enumWithInit() { + data class C(val c: BrasWithInit) + + val c = C (BrasWithInit.PUSHUP) + val obj = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(c)) + + assertEquals(c.c, obj.c) + } + + @Test(expected = NotSerializableException::class) + fun changedEnum1() { + val path = EnumTests::class.java.getResource("EnumTests.changedEnum1") + val f = File(path.toURI()) + + data class C (val a: OldBras) + + // Original version of the class for the serialised version of this class + // + // val a = OldBras.TSHIRT + // val sc = SerializationOutput(sf1).serialize(C(a)) + // f.writeBytes(sc.bytes) + // println(path) + + val sc2 = f.readBytes() + + // we expect this to throw + DeserializationInput(sf1).deserialize(SerializedBytes(sc2)) + } + + @Test(expected = NotSerializableException::class) + fun changedEnum2() { + val path = EnumTests::class.java.getResource("EnumTests.changedEnum2") + val f = File(path.toURI()) + + data class C (val a: OldBras2) + + // DO NOT CHANGE THIS, it's important we serialise with a value that doesn't + // change position in the upated enum class + + // Original version of the class for the serialised version of this class + // + // val a = OldBras2.UNDERWIRE + // val sc = SerializationOutput(sf1).serialize(C(a)) + // f.writeBytes(sc.bytes) + // println(path) + + val sc2 = f.readBytes() + + // we expect this to throw + DeserializationInput(sf1).deserialize(SerializedBytes(sc2)) + } +} \ No newline at end of file diff --git a/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumTests.changedEnum1 b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumTests.changedEnum1 new file mode 100644 index 0000000000..800c35bf17 Binary files /dev/null and b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumTests.changedEnum1 differ diff --git a/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumTests.changedEnum2 b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumTests.changedEnum2 new file mode 100644 index 0000000000..5f42911837 Binary files /dev/null and b/node-api/src/test/resources/net/corda/nodeapi/internal/serialization/amqp/EnumTests.changedEnum2 differ