diff --git a/client/jackson/src/main/kotlin/net/corda/client/jackson/internal/CordaModule.kt b/client/jackson/src/main/kotlin/net/corda/client/jackson/internal/CordaModule.kt index 2a2a11041d..388b9c3205 100644 --- a/client/jackson/src/main/kotlin/net/corda/client/jackson/internal/CordaModule.kt +++ b/client/jackson/src/main/kotlin/net/corda/client/jackson/internal/CordaModule.kt @@ -3,7 +3,7 @@ package net.corda.client.jackson.internal import com.fasterxml.jackson.annotation.* -import com.fasterxml.jackson.annotation.JsonCreator.Mode.DISABLED +import com.fasterxml.jackson.annotation.JsonCreator.Mode.* import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.core.JsonGenerator import com.fasterxml.jackson.core.JsonParseException @@ -38,10 +38,8 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.parseAsHex import net.corda.core.utilities.toHexString import net.corda.serialization.internal.AllWhitelist -import net.corda.serialization.internal.amqp.SerializerFactoryBuilder -import net.corda.serialization.internal.amqp.constructorForDeserialization -import net.corda.serialization.internal.amqp.hasCordaSerializable -import net.corda.serialization.internal.amqp.propertiesForSerialization +import net.corda.serialization.internal.amqp.* +import net.corda.serialization.internal.model.LocalTypeInformation import java.math.BigDecimal import java.security.PublicKey import java.security.cert.CertPath @@ -95,10 +93,11 @@ private class CordaSerializableBeanSerializerModifier : BeanSerializerModifier() beanProperties: MutableList): MutableList { val beanClass = beanDesc.beanClass if (hasCordaSerializable(beanClass) && beanClass.kotlinObjectInstance == null) { - val ctor = constructorForDeserialization(beanClass) - val amqpProperties = propertiesForSerialization(ctor, beanClass, serializerFactory) - .serializationOrder - .mapNotNull { if (it.isCalculated) null else it.serializer.name } + val typeInformation = serializerFactory.getTypeInformation(beanClass) + val properties = typeInformation.propertiesOrEmptyMap + val amqpProperties = properties.mapNotNull { (name, property) -> + if (property.isCalculated) null else name + } val propertyRenames = beanDesc.findProperties().associateBy({ it.name }, { it.internalName }) (amqpProperties - propertyRenames.values).let { check(it.isEmpty()) { "Jackson didn't provide serialisers for $it" } diff --git a/finance/src/test/kotlin/net/corda/finance/compat/CompatibilityTest.kt b/finance/src/test/kotlin/net/corda/finance/compat/CompatibilityTest.kt index bf1dc2cb88..b329fd1a17 100644 --- a/finance/src/test/kotlin/net/corda/finance/compat/CompatibilityTest.kt +++ b/finance/src/test/kotlin/net/corda/finance/compat/CompatibilityTest.kt @@ -1,16 +1,22 @@ package net.corda.finance.compat import net.corda.core.serialization.SerializationDefaults -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize +import net.corda.core.serialization.SerializedBytes import net.corda.core.transactions.SignedTransaction import net.corda.finance.contracts.asset.Cash +import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.amqp.DeserializationInput +import net.corda.serialization.internal.amqp.Schema +import net.corda.serialization.internal.amqp.SerializationOutput +import net.corda.serialization.internal.amqp.SerializerFactoryBuilder +import net.corda.serialization.internal.amqp.custom.PublicKeySerializer import net.corda.testing.core.SerializationEnvironmentRule import org.junit.Rule import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue +import kotlin.test.fail // TODO: If this type of testing gets momentum, we can create a mini-framework that rides through list of files // and performs necessary validation on all of them. @@ -20,19 +26,63 @@ class CompatibilityTest { @JvmField val testSerialization = SerializationEnvironmentRule() + val serializerFactory = SerializerFactoryBuilder.build(AllWhitelist, ClassLoader.getSystemClassLoader()).apply { + register(PublicKeySerializer) + } + @Test fun issueCashTansactionReadTest() { val inputStream = javaClass.classLoader.getResourceAsStream("compatibilityData/v3/node_transaction.dat") assertNotNull(inputStream) + val inByteArray: ByteArray = inputStream.readBytes() - val transaction = inByteArray.deserialize(context = SerializationDefaults.STORAGE_CONTEXT) + val input = DeserializationInput(serializerFactory) + + val (transaction, envelope) = input.deserializeAndReturnEnvelope( + SerializedBytes(inByteArray), + SignedTransaction::class.java, + SerializationDefaults.STORAGE_CONTEXT) assertNotNull(transaction) + val commands = transaction.tx.commands assertEquals(1, commands.size) assertTrue(commands.first().value is Cash.Commands.Issue) // Serialize back and check that representation is byte-to-byte identical to what it was originally. - val serializedForm = transaction.serialize(context = SerializationDefaults.STORAGE_CONTEXT) - assertTrue(inByteArray.contentEquals(serializedForm.bytes)) + val output = SerializationOutput(serializerFactory) + val (serializedBytes, schema) = output.serializeAndReturnSchema(transaction, SerializationDefaults.STORAGE_CONTEXT) + + assertSchemasMatch(envelope.schema, schema) + + assertTrue(inByteArray.contentEquals(serializedBytes.bytes)) + } + + private fun assertSchemasMatch(original: Schema, reserialized: Schema) { + if (original.toString() == reserialized.toString()) return + original.types.forEach { originalType -> + val reserializedType = reserialized.types.firstOrNull { it.name == originalType.name } ?: + fail("""Schema mismatch between original and re-serialized data. Could not find reserialized schema matching: + +$originalType +""") + + if (originalType.toString() != reserializedType.toString()) + fail("""Schema mismatch between original and re-serialized data. Expected: + +$originalType + +but was: + +$reserializedType +""") + } + + reserialized.types.forEach { reserializedType -> + if (original.types.none { it.name == reserializedType.name }) + fail("""Schema mismatch between original and re-serialized data. Could not find original schema matching: + +$reserializedType +""") + } } } \ No newline at end of file diff --git a/serialization-deterministic/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializerFactories.kt b/serialization-deterministic/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializerFactories.kt index d0bb4b4798..2b1d66ab76 100644 --- a/serialization-deterministic/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializerFactories.kt +++ b/serialization-deterministic/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPSerializerFactories.kt @@ -16,8 +16,8 @@ fun createSerializerFactoryFactory(): SerializerFactoryFactory = DeterministicSe private class DeterministicSerializerFactoryFactory : SerializerFactoryFactory { override fun make(context: SerializationContext) = SerializerFactoryBuilder.build( - whitelist = context.whitelist, - classCarpenter = DummyClassCarpenter(context.whitelist, context.deserializationClassLoader)) + whitelist = context.whitelist, + classCarpenter = DummyClassCarpenter(context.whitelist, context.deserializationClassLoader)) } private class DummyClassCarpenter( diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPPrimitiveSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPPrimitiveSerializer.kt index adccbe0bc7..5342f683a4 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPPrimitiveSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPPrimitiveSerializer.kt @@ -12,7 +12,7 @@ import java.lang.reflect.Type * [ByteArray] is automatically marshalled to/from the Proton-J wrapper, [Binary]. */ class AMQPPrimitiveSerializer(clazz: Class<*>) : AMQPSerializer { - override val typeDescriptor = Symbol.valueOf(SerializerFactory.primitiveTypeName(clazz)!!)!! + override val typeDescriptor = Symbol.valueOf(AMQPTypeIdentifiers.primitiveTypeName(clazz))!! override val type: Type = clazz // NOOP since this is a primitive type. diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPRemoteTypeModel.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPRemoteTypeModel.kt index f717f87f3b..6865ef4de4 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPRemoteTypeModel.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/AMQPRemoteTypeModel.kt @@ -3,6 +3,7 @@ package net.corda.serialization.internal.amqp import net.corda.serialization.internal.model.* import java.io.NotSerializableException import java.util.* +import kotlin.collections.LinkedHashMap /** * Interprets AMQP [Schema] information to obtain [RemoteTypeInformation], caching by [TypeDescriptor]. @@ -35,9 +36,17 @@ class AMQPRemoteTypeModel { val interpretationState = InterpretationState(notationLookup, enumTransformsLookup, cache, emptySet()) - return byTypeDescriptor.mapValues { (typeDescriptor, typeNotation) -> + val result = byTypeDescriptor.mapValues { (typeDescriptor, typeNotation) -> cache.getOrPut(typeDescriptor) { interpretationState.run { typeNotation.name.typeIdentifier.interpretIdentifier() } } } + val typesByIdentifier = result.values.associateBy { it.typeIdentifier } + result.values.forEach { typeInformation -> + if (typeInformation is RemoteTypeInformation.Cycle) { + typeInformation.follow = typesByIdentifier[typeInformation.typeIdentifier] ?: + throw NotSerializableException("Cannot resolve cyclic reference to ${typeInformation.typeIdentifier}") + } + } + return result } data class InterpretationState(val notationLookup: Map, @@ -45,9 +54,6 @@ class AMQPRemoteTypeModel { val cache: MutableMap, val seen: Set) { - private inline fun forgetSeen(block: InterpretationState.() -> T): T = - withSeen(emptySet(), block) - private inline fun withSeen(typeIdentifier: TypeIdentifier, block: InterpretationState.() -> T): T = withSeen(seen + typeIdentifier, block) @@ -62,7 +68,7 @@ class AMQPRemoteTypeModel { * know we have hit a cycle and respond accordingly. */ fun TypeIdentifier.interpretIdentifier(): RemoteTypeInformation = - if (this in seen) RemoteTypeInformation.Cycle(this) { forgetSeen { interpretIdentifier() } } + if (this in seen) RemoteTypeInformation.Cycle(this) else withSeen(this) { val identifier = this@interpretIdentifier notationLookup[identifier]?.interpretNotation(identifier) ?: interpretNoNotation() @@ -85,7 +91,7 @@ class AMQPRemoteTypeModel { * [RemoteTypeInformation]. */ private fun CompositeType.interpretComposite(identifier: TypeIdentifier): RemoteTypeInformation { - val properties = fields.asSequence().map { it.interpret() }.toMap() + val properties = fields.asSequence().sortedBy { it.name }.map { it.interpret() }.toMap(LinkedHashMap()) val typeParameters = identifier.interpretTypeParameters() val interfaceIdentifiers = provides.map { name -> name.typeIdentifier } val isInterface = identifier in interfaceIdentifiers @@ -175,6 +181,11 @@ class AMQPRemoteTypeModel { } } +fun LocalTypeInformation.getEnumTransforms(factory: LocalSerializerFactory): EnumTransforms { + val transformsSchema = TransformsSchema.get(typeIdentifier.name, factory) + return interpretTransformSet(transformsSchema) +} + private fun interpretTransformSet(transformSet: EnumMap>): EnumTransforms { val defaultTransforms = transformSet[TransformTypes.EnumDefault]?.toList() ?: emptyList() val defaults = defaultTransforms.associate { transform -> (transform as EnumDefaultSchemaTransform).new to transform.old } @@ -185,7 +196,7 @@ private fun interpretTransformSet(transformSet: EnumMap { +open class ArraySerializer(override val type: Type, factory: LocalSerializerFactory) : AMQPSerializer { companion object { - fun make(type: Type, factory: SerializerFactory) : AMQPSerializer { + fun make(type: Type, factory: LocalSerializerFactory) : AMQPSerializer { contextLogger().debug { "Making array serializer, typename=${type.typeName}" } return when (type) { Array::class.java -> CharArraySerializer(factory) @@ -41,8 +41,8 @@ open class ArraySerializer(override val type: Type, factory: SerializerFactory) // Special case handler for primitive byte arrays. This is needed because we can silently // coerce a byte[] to our own binary type. Normally, if the component type was itself an // array we'd keep walking down the chain but for byte[] stop here and use binary instead - val typeName = if (SerializerFactory.isPrimitive(type.componentType())) { - SerializerFactory.nameForType(type.componentType()) + val typeName = if (AMQPTypeIdentifiers.isPrimitive(type.componentType())) { + AMQPTypeIdentifiers.nameForType(type.componentType()) } else { calcTypeName(type.componentType(), debugOffset + 4) } @@ -55,7 +55,7 @@ open class ArraySerializer(override val type: Type, factory: SerializerFactory) } override val typeDescriptor: Symbol by lazy { - Symbol.valueOf("$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}") + factory.createDescriptor(type) } internal val elementType: Type by lazy { type.componentType() } @@ -103,7 +103,7 @@ open class ArraySerializer(override val type: Type, factory: SerializerFactory) // Boxed Character arrays required a specialisation to handle the type conversion properly when populating // the array since Kotlin won't allow an implicit cast from Int (as they're stored as 16bit ints) to Char -class CharArraySerializer(factory: SerializerFactory) : ArraySerializer(Array::class.java, factory) { +class CharArraySerializer(factory: LocalSerializerFactory) : ArraySerializer(Array::class.java, factory) { override fun List.toArrayOfType(type: Type): Any { val elementType = type.asClass() val list = this @@ -114,11 +114,11 @@ class CharArraySerializer(factory: SerializerFactory) : ArraySerializer(Array PrimArraySerializer> = mapOf( + private val primTypes: Map PrimArraySerializer> = mapOf( IntArray::class.java to { f -> PrimIntArraySerializer(f) }, CharArray::class.java to { f -> PrimCharArraySerializer(f) }, BooleanArray::class.java to { f -> PrimBooleanArraySerializer(f) }, @@ -129,7 +129,7 @@ abstract class PrimArraySerializer(type: Type, factory: SerializerFactory) : Arr // ByteArray::class.java <-> NOT NEEDED HERE (see comment above) ) - fun make(type: Type, factory: SerializerFactory) = primTypes[type]!!(factory) + fun make(type: Type, factory: LocalSerializerFactory) = primTypes[type]!!(factory) } fun localWriteObject(data: Data, func: () -> Unit) { @@ -137,7 +137,7 @@ abstract class PrimArraySerializer(type: Type, factory: SerializerFactory) : Arr } } -class PrimIntArraySerializer(factory: SerializerFactory) : PrimArraySerializer(IntArray::class.java, factory) { +class PrimIntArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(IntArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int ) { @@ -147,7 +147,7 @@ class PrimIntArraySerializer(factory: SerializerFactory) : PrimArraySerializer(I } } -class PrimCharArraySerializer(factory: SerializerFactory) : PrimArraySerializer(CharArray::class.java, factory) { +class PrimCharArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(CharArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int ) { @@ -168,7 +168,7 @@ class PrimCharArraySerializer(factory: SerializerFactory) : PrimArraySerializer( } } -class PrimBooleanArraySerializer(factory: SerializerFactory) : PrimArraySerializer(BooleanArray::class.java, factory) { +class PrimBooleanArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(BooleanArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int ) { @@ -178,7 +178,7 @@ class PrimBooleanArraySerializer(factory: SerializerFactory) : PrimArraySerializ } } -class PrimDoubleArraySerializer(factory: SerializerFactory) : +class PrimDoubleArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(DoubleArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int @@ -189,7 +189,7 @@ class PrimDoubleArraySerializer(factory: SerializerFactory) : } } -class PrimFloatArraySerializer(factory: SerializerFactory) : +class PrimFloatArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(FloatArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int) { @@ -199,7 +199,7 @@ class PrimFloatArraySerializer(factory: SerializerFactory) : } } -class PrimShortArraySerializer(factory: SerializerFactory) : +class PrimShortArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(ShortArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int @@ -210,7 +210,7 @@ class PrimShortArraySerializer(factory: SerializerFactory) : } } -class PrimLongArraySerializer(factory: SerializerFactory) : +class PrimLongArraySerializer(factory: LocalSerializerFactory) : PrimArraySerializer(LongArray::class.java, factory) { override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CollectionSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CollectionSerializer.kt index 7df6abd9ed..29953e840a 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CollectionSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CollectionSerializer.kt @@ -1,11 +1,13 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM -import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext import net.corda.core.utilities.NonEmptySet +import net.corda.serialization.internal.model.LocalTypeInformation +import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data +import java.io.NotSerializableException import java.lang.reflect.ParameterizedType import java.lang.reflect.Type import java.util.* @@ -15,11 +17,11 @@ import kotlin.collections.LinkedHashSet * Serialization / deserialization of predefined set of supported [Collection] types covering mostly [List]s and [Set]s. */ @KeepForDJVM -class CollectionSerializer(private val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer { - override val type: Type = declaredType as? DeserializedParameterizedType - ?: DeserializedParameterizedType.make(SerializerFactory.nameForType(declaredType)) +class CollectionSerializer(private val declaredType: ParameterizedType, factory: LocalSerializerFactory) : AMQPSerializer { + override val type: Type = declaredType + override val typeDescriptor: Symbol by lazy { - Symbol.valueOf("$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}") + factory.createDescriptor(type) } companion object { @@ -33,40 +35,60 @@ class CollectionSerializer(private val declaredType: ParameterizedType, factory: NonEmptySet::class.java to { list -> NonEmptySet.copyOf(list) } )) + private val supportedTypeIdentifiers = supportedTypes.keys.asSequence().map { TypeIdentifier.forClass(it) }.toSet() + + /** + * Replace erased collection types with parameterised types with wildcard type parameters, so that they are represented + * appropriately in the AMQP schema. + */ + fun resolveDeclared(declaredTypeInformation: LocalTypeInformation.ACollection): LocalTypeInformation.ACollection { + if (declaredTypeInformation.typeIdentifier.erased in supportedTypeIdentifiers) + return reparameterise(declaredTypeInformation) + + throw NotSerializableException( + "Cannot derive collection type for declared type: " + + declaredTypeInformation.prettyPrint(false)) + } + + fun resolveActual(actualClass: Class<*>, declaredTypeInformation: LocalTypeInformation.ACollection): LocalTypeInformation.ACollection { + if (declaredTypeInformation.typeIdentifier.erased in supportedTypeIdentifiers) + return reparameterise(declaredTypeInformation) + + val collectionClass = findMostSuitableCollectionType(actualClass) + val erasedInformation = LocalTypeInformation.ACollection( + collectionClass, + TypeIdentifier.forClass(collectionClass), + LocalTypeInformation.Unknown) + + return when(declaredTypeInformation.typeIdentifier) { + is TypeIdentifier.Parameterised -> erasedInformation.withElementType(declaredTypeInformation.elementType) + else -> erasedInformation.withElementType(LocalTypeInformation.Unknown) + } + } + + private fun reparameterise(typeInformation: LocalTypeInformation.ACollection): LocalTypeInformation.ACollection = + when(typeInformation.typeIdentifier) { + is TypeIdentifier.Parameterised -> typeInformation + is TypeIdentifier.Erased -> typeInformation.withElementType(LocalTypeInformation.Unknown) + else -> throw NotSerializableException( + "Unexpected type identifier ${typeInformation.typeIdentifier.prettyPrint(false)} " + + "for collection type ${typeInformation.prettyPrint(false)}") + } + + private fun findMostSuitableCollectionType(actualClass: Class<*>): Class> = + supportedTypes.keys.findLast { it.isAssignableFrom(actualClass) }!! + private fun findConcreteType(clazz: Class<*>): (List<*>) -> Collection<*> { return supportedTypes[clazz] ?: throw AMQPNotSerializableException( clazz, "Unsupported collection type $clazz.", "Supported Collections are ${supportedTypes.keys.joinToString(",")}") } - - fun deriveParameterizedType(declaredType: Type, declaredClass: Class<*>, actualClass: Class<*>?): ParameterizedType { - if (supportedTypes.containsKey(declaredClass)) { - // Simple case - it is already known to be a collection. - return deriveParametrizedType(declaredType, uncheckedCast(declaredClass)) - } else if (actualClass != null && Collection::class.java.isAssignableFrom(actualClass)) { - // Declared class is not collection, but [actualClass] is - represent it accordingly. - val collectionClass = findMostSuitableCollectionType(actualClass) - return deriveParametrizedType(declaredType, collectionClass) - } - - throw AMQPNotSerializableException( - declaredType, - "Cannot derive collection type for declaredType: '$declaredType', " + - "declaredClass: '$declaredClass', actualClass: '$actualClass'") - } - - private fun deriveParametrizedType(declaredType: Type, collectionClass: Class>): ParameterizedType = - (declaredType as? ParameterizedType) - ?: DeserializedParameterizedType(collectionClass, arrayOf(SerializerFactory.AnyType)) - - private fun findMostSuitableCollectionType(actualClass: Class<*>): Class> = - supportedTypes.keys.findLast { it.isAssignableFrom(actualClass) }!! } private val concreteBuilder: (List<*>) -> Collection<*> = findConcreteType(declaredType.rawType as Class<*>) - private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor), emptyList()) + private val typeNotation: TypeNotation = RestrictedType(AMQPTypeIdentifiers.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor), emptyList()) private val outboundType = resolveTypeVariables(declaredType.actualTypeArguments[0], null) private val inboundType = declaredType.actualTypeArguments[0] diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ComposableTypePropertySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ComposableTypePropertySerializer.kt new file mode 100644 index 0000000000..aa9764fd1e --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ComposableTypePropertySerializer.kt @@ -0,0 +1,270 @@ +package net.corda.serialization.internal.amqp + +import net.corda.core.serialization.SerializationContext +import net.corda.serialization.internal.model.* +import org.apache.qpid.proton.amqp.Binary +import org.apache.qpid.proton.codec.Data +import java.lang.reflect.Method +import java.lang.reflect.Field +import java.lang.reflect.Type + +/** + * A strategy for reading a property value during deserialization. + */ +interface PropertyReadStrategy { + + companion object { + /** + * Select the correct strategy for reading properties, based on the property type. + */ + fun make(name: String, typeIdentifier: TypeIdentifier, type: Type): PropertyReadStrategy = + if (AMQPTypeIdentifiers.isPrimitive(typeIdentifier)) { + when (typeIdentifier) { + in characterTypes -> AMQPCharPropertyReadStrategy + else -> AMQPPropertyReadStrategy + } + } else { + DescribedTypeReadStrategy(name, typeIdentifier, type) + } + } + + /** + * Use this strategy to read the value of a property during deserialization. + */ + fun readProperty(obj: Any?, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any? + +} + +/** + * A strategy for writing a property value during serialisation. + */ +interface PropertyWriteStrategy { + + companion object { + /** + * Select the correct strategy for writing properties, based on the property information. + */ + fun make(name: String, propertyInformation: LocalPropertyInformation, factory: LocalSerializerFactory): PropertyWriteStrategy { + val reader = PropertyReader.make(propertyInformation) + val type = propertyInformation.type + return if (AMQPTypeIdentifiers.isPrimitive(type.typeIdentifier)) { + when (type.typeIdentifier) { + in characterTypes -> AMQPCharPropertyWriteStategy(reader) + else -> AMQPPropertyWriteStrategy(reader) + } + } else { + DescribedTypeWriteStrategy(name, propertyInformation, reader) { factory.get(propertyInformation.type) } + } + } + } + + /** + * Write any [TypeNotation] needed to the [SerializationOutput]. + */ + fun writeClassInfo(output: SerializationOutput) + + /** + * Write the property's value to the [SerializationOutput]. + */ + fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, context: SerializationContext, debugIndent: Int) +} + +/** + * Combines strategies for reading and writing a given property's value during serialisation/deserialisation. + */ +interface PropertySerializer : PropertyReadStrategy, PropertyWriteStrategy { + /** + * The name of the property. + */ + val name: String + /** + * Whether the property is calculated. + */ + val isCalculated: Boolean +} + +/** + * A [PropertySerializer] for a property of a [LocalTypeInformation.Composable] type. + */ +class ComposableTypePropertySerializer( + override val name: String, + override val isCalculated: Boolean, + private val readStrategy: PropertyReadStrategy, + private val writeStrategy: PropertyWriteStrategy) : + PropertySerializer, + PropertyReadStrategy by readStrategy, + PropertyWriteStrategy by writeStrategy { + + companion object { + /** + * Make a [PropertySerializer] for the given [LocalPropertyInformation]. + * + * @param name The name of the property. + * @param propertyInformation [LocalPropertyInformation] for the property. + * @param factory The [LocalSerializerFactory] to use when writing values for this property. + */ + fun make(name: String, propertyInformation: LocalPropertyInformation, factory: LocalSerializerFactory): PropertySerializer = + ComposableTypePropertySerializer( + name, + propertyInformation.isCalculated, + PropertyReadStrategy.make(name, propertyInformation.type.typeIdentifier, propertyInformation.type.observedType), + PropertyWriteStrategy.make(name, propertyInformation, factory)) + + /** + * Make a [PropertySerializer] for use in deserialization only, when deserializing a type that requires evolution. + * + * @param name The name of the property. + * @param isCalculated Whether the property is calculated. + * @param typeIdentifier The [TypeIdentifier] for the property type. + * @param type The local [Type] for the property type. + */ + fun makeForEvolution(name: String, isCalculated: Boolean, typeIdentifier: TypeIdentifier, type: Type): PropertySerializer = + ComposableTypePropertySerializer( + name, + isCalculated, + PropertyReadStrategy.make(name, typeIdentifier, type), + EvolutionPropertyWriteStrategy) + } +} + +/** + * Obtains the value of a property from an instance of the type to which that property belongs, either by calling a getter method + * or by reading the value of a private backing field. + */ +sealed class PropertyReader { + + companion object { + /** + * Make a [PropertyReader] based on the provided [LocalPropertyInformation]. + */ + fun make(propertyInformation: LocalPropertyInformation) = when(propertyInformation) { + is LocalPropertyInformation.GetterSetterProperty -> GetterReader(propertyInformation.observedGetter) + is LocalPropertyInformation.ConstructorPairedProperty -> GetterReader(propertyInformation.observedGetter) + is LocalPropertyInformation.ReadOnlyProperty -> GetterReader(propertyInformation.observedGetter) + is LocalPropertyInformation.CalculatedProperty -> GetterReader(propertyInformation.observedGetter) + is LocalPropertyInformation.PrivateConstructorPairedProperty -> FieldReader(propertyInformation.observedField) + } + } + + /** + * Get the value of the property from the supplied instance, or null if the instance is itself null. + */ + abstract fun read(obj: Any?): Any? + + /** + * Reads a property using a getter [Method]. + */ + class GetterReader(private val getter: Method): PropertyReader() { + init { + getter.isAccessible = true + } + + override fun read(obj: Any?): Any? = if (obj == null) null else getter.invoke(obj) + } + + /** + * Reads a property using a backing [Field]. + */ + class FieldReader(private val field: Field): PropertyReader() { + init { + field.isAccessible = true + } + + override fun read(obj: Any?): Any? = if (obj == null) null else field.get(obj) + } +} + +private val characterTypes = setOf( + TypeIdentifier.forClass(Char::class.javaObjectType), + TypeIdentifier.forClass(Char::class.javaPrimitiveType!!) +) + +object EvolutionPropertyWriteStrategy : PropertyWriteStrategy { + override fun writeClassInfo(output: SerializationOutput) = + throw UnsupportedOperationException("Evolution serializers cannot write values") + + override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, context: SerializationContext, debugIndent: Int) = + throw UnsupportedOperationException("Evolution serializers cannot write values") +} + +/** + * Read a type that comes with its own [TypeDescriptor], by calling back into [RemoteSerializerFactory] to obtain a suitable + * serializer for that descriptor. + */ +class DescribedTypeReadStrategy(name: String, + typeIdentifier: TypeIdentifier, + private val type: Type): PropertyReadStrategy { + + private val nameForDebug = "$name(${typeIdentifier.prettyPrint(false)})" + + override fun readProperty(obj: Any?, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any? = + ifThrowsAppend({ nameForDebug }) { + input.readObjectOrNull(obj, schemas, type, context) + } +} + +/** + * Writes a property value into [SerializationOutput], together with a schema information describing it. + */ +class DescribedTypeWriteStrategy(private val name: String, + private val propertyInformation: LocalPropertyInformation, + private val reader: PropertyReader, + private val serializerProvider: () -> AMQPSerializer) : PropertyWriteStrategy { + + // Lazy to avoid getting into infinite loops when there are cycles. + private val serializer by lazy { serializerProvider() } + + private val nameForDebug get() = "$name(${propertyInformation.type.typeIdentifier.prettyPrint(false)})" + + override fun writeClassInfo(output: SerializationOutput) { + if (propertyInformation.type !is LocalTypeInformation.Top) { + serializer.writeClassInfo(output) + } + } + + override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, context: SerializationContext, + debugIndent: Int) = ifThrowsAppend({ nameForDebug }) { + val propertyValue = reader.read(obj) + output.writeObjectOrNull(propertyValue, data, propertyInformation.type.observedType, context, debugIndent) + } +} + +object AMQPPropertyReadStrategy : PropertyReadStrategy { + override fun readProperty(obj: Any?, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any? = + if (obj is Binary) obj.array else obj +} + +class AMQPPropertyWriteStrategy(private val reader: PropertyReader) : PropertyWriteStrategy { + override fun writeClassInfo(output: SerializationOutput) {} + + override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, + context: SerializationContext, debugIndent: Int + ) { + val value = reader.read(obj) + // ByteArrays have to be wrapped in an AMQP Binary wrapper. + if (value is ByteArray) { + data.putObject(Binary(value)) + } else { + data.putObject(value) + } + } +} + +object AMQPCharPropertyReadStrategy : PropertyReadStrategy { + override fun readProperty(obj: Any?, schemas: SerializationSchemas, + input: DeserializationInput, context: SerializationContext + ): Any? { + return if (obj == null) null else (obj as Short).toChar() + } +} + +class AMQPCharPropertyWriteStategy(private val reader: PropertyReader) : PropertyWriteStrategy { + override fun writeClassInfo(output: SerializationOutput) {} + + override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, + context: SerializationContext, debugIndent: Int + ) { + val input = reader.read(obj) + if (input != null) data.putShort((input as Char).toShort()) else data.putNull() + } +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt index f1b509c940..b614fc7ac7 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CorDappCustomSerializer.kt @@ -4,7 +4,6 @@ import com.google.common.reflect.TypeToken import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationCustomSerializer -import net.corda.serialization.internal.amqp.SerializerFactory.Companion.nameForType import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type @@ -63,9 +62,11 @@ class CorDappCustomSerializer( override val type = types[CORDAPP_TYPE] val proxyType = types[PROXY_TYPE] - override val typeDescriptor: Symbol = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${nameForType(type)}") + override val typeDescriptor: Symbol = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${AMQPTypeIdentifiers.nameForType(type)}") val descriptor: Descriptor = Descriptor(typeDescriptor) - private val proxySerializer: ObjectSerializer by lazy { ObjectSerializer(proxyType, factory) } + private val proxySerializer: ObjectSerializer by lazy { + ObjectSerializer.make(factory.getTypeInformation(proxyType), factory) + } override fun writeClassInfo(output: SerializationOutput) {} @@ -77,8 +78,8 @@ class CorDappCustomSerializer( data.withDescribed(descriptor) { data.withList { - proxySerializer.propertySerializers.serializationOrder.forEach { - it.serializer.writeProperty(proxy, this, output, context) + (proxySerializer as ObjectSerializer).propertySerializers.forEach { (_, serializer) -> + serializer.writeProperty(proxy, this, output, context, debugIndent) } } } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt index cbd54f08c2..8e44e3ab56 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializer.kt @@ -3,7 +3,7 @@ package net.corda.serialization.internal.amqp import net.corda.core.KeepForDJVM import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext -import net.corda.serialization.internal.amqp.SerializerFactory.Companion.nameForType +import net.corda.serialization.internal.model.FingerprintWriter import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type @@ -67,13 +67,13 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { override fun isSerializerFor(clazz: Class<*>): Boolean = clazz == this.clazz override val type: Type get() = clazz override val typeDescriptor: Symbol by lazy { - Symbol.valueOf("$DESCRIPTOR_DOMAIN:${fingerprintForDescriptors(superClassSerializer.typeDescriptor.toString(), nameForType(clazz))}") + Symbol.valueOf("$DESCRIPTOR_DOMAIN:${FingerprintWriter(false).write(arrayOf(superClassSerializer.typeDescriptor.toString(), AMQPTypeIdentifiers.nameForType(clazz)).joinToString()).fingerprint}") } private val typeNotation: TypeNotation = RestrictedType( - SerializerFactory.nameForType(clazz), + AMQPTypeIdentifiers.nameForType(clazz), null, emptyList(), - SerializerFactory.nameForType(superClassSerializer.type), + AMQPTypeIdentifiers.nameForType(superClassSerializer.type), Descriptor(typeDescriptor), emptyList()) @@ -102,7 +102,7 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { */ abstract class CustomSerializerImp(protected val clazz: Class, protected val withInheritance: Boolean) : CustomSerializer() { override val type: Type get() = clazz - override val typeDescriptor: Symbol = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${nameForType(clazz)}") + override val typeDescriptor: Symbol = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${AMQPTypeIdentifiers.nameForType(clazz)}") override fun writeClassInfo(output: SerializationOutput) {} override val descriptor: Descriptor = Descriptor(typeDescriptor) override fun isSerializerFor(clazz: Class<*>): Boolean = if (withInheritance) this.clazz.isAssignableFrom(clazz) else this.clazz == clazz @@ -127,19 +127,19 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { */ abstract class Proxy(clazz: Class, protected val proxyClass: Class

, - protected val factory: SerializerFactory, + protected val factory: LocalSerializerFactory, withInheritance: Boolean = true) : CustomSerializerImp(clazz, withInheritance) { override fun isSerializerFor(clazz: Class<*>): Boolean = if (withInheritance) this.clazz.isAssignableFrom(clazz) else this.clazz == clazz - private val proxySerializer: ObjectSerializer by lazy { ObjectSerializer(proxyClass, factory) } + private val proxySerializer: ObjectSerializer by lazy { ObjectSerializer.make(factory.getTypeInformation(proxyClass), factory) } override val schemaForDocumentation: Schema by lazy { val typeNotations = mutableSetOf( CompositeType( - nameForType(type), + AMQPTypeIdentifiers.nameForType(type), null, emptyList(), - descriptor, (proxySerializer.typeNotation as CompositeType).fields)) + descriptor, proxySerializer.fields)) for (additional in additionalSerializers) { typeNotations.addAll(additional.schemaForDocumentation.types) } @@ -158,8 +158,8 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { ) { val proxy = toProxy(obj) data.withList { - proxySerializer.propertySerializers.serializationOrder.forEach { - it.serializer.writeProperty(proxy, this, output, context) + proxySerializer.propertySerializers.forEach { (_, serializer) -> + serializer.writeProperty(proxy, this, output, context, 0) } } } @@ -191,8 +191,8 @@ abstract class CustomSerializer : AMQPSerializer, SerializerFor { : CustomSerializerImp(clazz, withInheritance) { override val schemaForDocumentation = Schema( - listOf(RestrictedType(nameForType(type), "", listOf(nameForType(type)), - SerializerFactory.primitiveTypeName(String::class.java)!!, + listOf(RestrictedType(AMQPTypeIdentifiers.nameForType(type), "", listOf(AMQPTypeIdentifiers.nameForType(type)), + AMQPTypeIdentifiers.primitiveTypeName(String::class.java), descriptor, emptyList()))) override fun writeDescribedObject(obj: T, data: Data, type: Type, output: SerializationOutput, diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DescriptorBasedSerializerRegistry.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DescriptorBasedSerializerRegistry.kt index 2a2d17127d..8adc48fbed 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DescriptorBasedSerializerRegistry.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DescriptorBasedSerializerRegistry.kt @@ -25,5 +25,5 @@ class DefaultDescriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistr } override fun getOrBuild(descriptor: String, builder: () -> AMQPSerializer) = - get(descriptor) ?: builder().also { newSerializer -> this[descriptor] = newSerializer } + registry.getOrPut(descriptor) { builder() } } \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt index 7bc06c5a24..f5ed5dd924 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt @@ -8,6 +8,7 @@ import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.loggerFor import net.corda.serialization.internal.* +import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.DescribedType import org.apache.qpid.proton.amqp.UnsignedInteger @@ -168,8 +169,8 @@ class DeserializationInput constructor( val objectRead = when (obj) { is DescribedType -> { // Look up serializer in factory by descriptor - val serializer = serializerFactory.get(obj.descriptor, schemas) - if (SerializerFactory.AnyType != type && serializer.type != type && with(serializer.type) { + val serializer = serializerFactory.get(obj.descriptor.toString(), schemas) + if (type != TypeIdentifier.UnknownType.getLocalType() && serializer.type != type && with(serializer.type) { !isSubClassOf(type) && !materiallyEquivalentTo(type) } ) { diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializedGenericArrayType.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializedGenericArrayType.kt deleted file mode 100644 index 364b5afa6e..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializedGenericArrayType.kt +++ /dev/null @@ -1,18 +0,0 @@ -package net.corda.serialization.internal.amqp - -import java.lang.reflect.GenericArrayType -import java.lang.reflect.Type -import java.util.* - -/** - * Implementation of [GenericArrayType] that we can actually construct. - */ -class DeserializedGenericArrayType(private val componentType: Type) : GenericArrayType { - override fun getGenericComponentType(): Type = componentType - override fun getTypeName(): String = "${componentType.typeName}[]" - override fun toString(): String = typeName - override fun hashCode(): Int = Objects.hashCode(componentType) - override fun equals(other: Any?): Boolean { - return other is GenericArrayType && (componentType == other.genericComponentType) - } -} diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializedParameterizedType.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializedParameterizedType.kt deleted file mode 100644 index f6321e1dc3..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializedParameterizedType.kt +++ /dev/null @@ -1,174 +0,0 @@ -package net.corda.serialization.internal.amqp - -import com.google.common.primitives.Primitives -import net.corda.core.KeepForDJVM -import java.io.NotSerializableException -import java.lang.reflect.ParameterizedType -import java.lang.reflect.Type -import java.lang.reflect.TypeVariable -import java.util.* - -/** - * Implementation of [ParameterizedType] that we can actually construct, and a parser from the string representation - * of the JDK implementation which we use as the textual format in the AMQP schema. - */ -@KeepForDJVM -class DeserializedParameterizedType( - private val rawType: Class<*>, - private val params: Array, - private val ownerType: Type? = null -) : ParameterizedType { - init { - if (params.isEmpty()) { - throw AMQPNotSerializableException(rawType, "Must be at least one parameter type in a ParameterizedType") - } - if (params.size != rawType.typeParameters.size) { - throw AMQPNotSerializableException( - rawType, - "Expected ${rawType.typeParameters.size} for ${rawType.name} but found ${params.size}") - } - } - - private fun boundedType(type: TypeVariable>): Boolean { - return !(type.bounds.size == 1 && type.bounds[0] == Object::class.java) - } - - private val _typeName: String = makeTypeName() - - private fun makeTypeName(): String { - val paramsJoined = params.joinToString(", ") { it.typeName } - return "${rawType.name}<$paramsJoined>" - } - - companion object { - // Maximum depth/nesting of generics before we suspect some DoS attempt. - const val MAX_DEPTH: Int = 32 - - fun make(name: String, cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader): Type { - val paramTypes = ArrayList() - val pos = parseTypeList("$name>", paramTypes, cl) - if (pos <= name.length) { - throw AMQPNoTypeNotSerializableException( - "Malformed string form of ParameterizedType. Unexpected '>' at character position $pos of $name.") - } - if (paramTypes.size != 1) { - throw AMQPNoTypeNotSerializableException("Expected only one type, but got $paramTypes") - } - return paramTypes[0] - } - - private fun parseTypeList(params: String, types: MutableList, cl: ClassLoader, depth: Int = 0): Int { - var pos = 0 - var typeStart = 0 - var needAType = true - var skippingWhitespace = false - - while (pos < params.length) { - if (params[pos] == '<') { - val typeEnd = pos++ - val paramTypes = ArrayList() - pos = parseTypeParams(params, pos, paramTypes, cl, depth + 1) - types += makeParameterizedType(params.substring(typeStart, typeEnd).trim(), paramTypes, cl) - typeStart = pos - needAType = false - } else if (params[pos] == ',') { - val typeEnd = pos++ - val typeName = params.substring(typeStart, typeEnd).trim() - if (!typeName.isEmpty()) { - types += makeType(typeName, cl) - } else if (needAType) { - throw AMQPNoTypeNotSerializableException("Expected a type, not ','") - } - typeStart = pos - needAType = true - } else if (params[pos] == '>') { - val typeEnd = pos++ - val typeName = params.substring(typeStart, typeEnd).trim() - if (!typeName.isEmpty()) { - types += makeType(typeName, cl) - } else if (needAType) { - throw AMQPNoTypeNotSerializableException("Expected a type, not '>'") - } - return pos - } else { - // Skip forwards, checking character types - if (pos == typeStart) { - skippingWhitespace = false - if (params[pos].isWhitespace()) { - typeStart = ++pos - } else if (!needAType) { - throw AMQPNoTypeNotSerializableException("Not expecting a type") - } else if (params[pos] == '?') { - pos++ - } else if (!params[pos].isJavaIdentifierStart()) { - throw AMQPNoTypeNotSerializableException("Invalid character at start of type: ${params[pos]}") - } else { - pos++ - } - } else { - if (params[pos].isWhitespace()) { - pos++ - skippingWhitespace = true - } else if (!skippingWhitespace && (params[pos] == '.' || params[pos].isJavaIdentifierPart())) { - pos++ - } else { - throw AMQPNoTypeNotSerializableException( - "Invalid character ${params[pos]} in middle of type $params at idx $pos") - } - } - } - } - throw AMQPNoTypeNotSerializableException("Missing close generics '>'") - } - - private fun makeType(typeName: String, cl: ClassLoader): Type { - // Not generic - return if (typeName == "?") SerializerFactory.AnyType else { - Primitives.wrap(SerializerFactory.primitiveType(typeName) ?: Class.forName(typeName, false, cl)) - } - } - - private fun makeParameterizedType(rawTypeName: String, args: MutableList, cl: ClassLoader): Type { - return DeserializedParameterizedType(makeType(rawTypeName, cl) as Class<*>, args.toTypedArray(), null) - } - - private fun parseTypeParams( - params: String, - startPos: Int, - paramTypes: MutableList, - cl: ClassLoader, - depth: Int - ): Int { - if (depth == MAX_DEPTH) { - throw AMQPNoTypeNotSerializableException("Maximum depth of nested generics reached: $depth") - } - return startPos + parseTypeList(params.substring(startPos), paramTypes, cl, depth) - } - } - - override fun getRawType(): Type = rawType - - override fun getOwnerType(): Type? = ownerType - - override fun getActualTypeArguments(): Array = params - - override fun getTypeName(): String = _typeName - - override fun toString(): String = _typeName - - override fun hashCode(): Int { - return Arrays.hashCode(this.actualTypeArguments) xor Objects.hashCode(this.ownerType) xor Objects.hashCode(this.rawType) - } - - override fun equals(other: Any?): Boolean { - return if (other is ParameterizedType) { - if (this === other) { - true - } else { - this.ownerType == other.ownerType && this.rawType == other.rawType && Arrays.equals(this.actualTypeArguments, other.actualTypeArguments) - } - } else { - false - } - } -} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt index 5e7010c71c..ec9ef9e678 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumEvolutionSerializer.kt @@ -2,6 +2,7 @@ package net.corda.serialization.internal.amqp import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext +import net.corda.serialization.internal.model.LocalTypeInformation import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException @@ -37,100 +38,20 @@ import java.util.* */ class EnumEvolutionSerializer( override val type: Type, - factory: SerializerFactory, + factory: LocalSerializerFactory, private val conversions: Map, private val ordinals: Map) : AMQPSerializer { - override val typeDescriptor = Symbol.valueOf( - "$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}")!! - - companion object { - private fun MutableMap.mapInPlace(f: (String) -> String) { - val i = iterator() - while (i.hasNext()) { - val curr = i.next() - curr.setValue(f(curr.value)) - } - } - - /** - * Builds an Enum Evolver serializer. - * - * @param old The description of the enum as it existed at the time of serialisation taken from the - * received AMQP header - * @param new The Serializer object we built based on the current state of the enum class on our classpath - * @param factory the [SerializerFactory] that is building this serialization object. - * @param schemas the transforms attached to the class in the AMQP header, i.e. the transforms - * known at serialization time - */ - fun make(old: RestrictedType, - new: AMQPSerializer, - factory: SerializerFactory, - schemas: SerializationSchemas): AMQPSerializer { - val wireTransforms = schemas.transforms.types[old.name] - ?: EnumMap>(TransformTypes::class.java) - val localTransforms = TransformsSchema.get(old.name, factory) - - // remember, the longer the list the newer we're assuming the transform set it as we assume - // evolution annotations are never removed, only added to - val transforms = if (wireTransforms.size > localTransforms.size) wireTransforms else localTransforms - - // if either of these isn't of the cast type then something has gone terribly wrong - // elsewhere in the code - val defaultRules: List? = uncheckedCast(transforms[TransformTypes.EnumDefault]) - val renameRules: List? = uncheckedCast(transforms[TransformTypes.Rename]) - - // What values exist on the enum as it exists on the class path - val localValues = new.type.asClass().enumConstants.map { it.toString() } - - val conversions: MutableMap = localValues - .union(defaultRules?.map { it.new }?.toSet() ?: emptySet()) - .union(renameRules?.map { it.to } ?: emptySet()) - .associateBy({ it }, { it }) - .toMutableMap() - - val rules: MutableMap = mutableMapOf() - rules.putAll(defaultRules?.associateBy({ it.new }, { it.old }) ?: emptyMap()) - val renameRulesMap = renameRules?.associateBy({ it.to }, { it.from }) ?: emptyMap() - rules.putAll(renameRulesMap) - - // take out set of all possible constants and build a map from those to the - // existing constants applying the rename and defaulting rules as defined - // in the schema - while (conversions.filterNot { it.value in localValues }.isNotEmpty()) { - conversions.mapInPlace { rules[it] ?: it } - } - - // you'd think this was overkill to get access to the ordinal values for each constant but it's actually - // rather tricky when you don't have access to the actual type, so this is a nice way to be able - // to precompute and pass to the actual object - val ordinals = localValues.mapIndexed { i, s -> Pair(s, i) }.toMap() - - // create a mapping between the ordinal value and the name as it was serialised converted - // to the name as it exists. We want to test any new constants have been added to the end - // of the enum class - val serialisedOrds = ((schemas.schema.types.find { it.name == old.name } as RestrictedType).choices - .associateBy({ it.value.toInt() }, { conversions[it.name] })) - - if (ordinals.filterNot { serialisedOrds[it.value] == it.key }.isNotEmpty()) { - throw AMQPNotSerializableException( - new.type, - "Constants have been reordered, additions must be appended to the end") - } - - return EnumEvolutionSerializer(new.type, factory, conversions, ordinals) - } - } + override val typeDescriptor = factory.createDescriptor(type) override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext ): Any { val enumName = (obj as List<*>)[0] as String - if (enumName !in conversions) { - throw AMQPNotSerializableException(type, "No rule to evolve enum constant $type::$enumName") - } + val converted = conversions[enumName] ?: throw AMQPNotSerializableException(type, "No rule to evolve enum constant $type::$enumName") + val ordinal = ordinals[converted] ?: throw AMQPNotSerializableException(type, "Ordinal not found for enum value $type::$converted") - return type.asClass().enumConstants[ordinals[conversions[enumName]]!!] + return type.asClass().enumConstants[ordinal] } override fun writeClassInfo(output: SerializationOutput) { diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt index 1bb12190f2..da8b922649 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EnumSerializer.kt @@ -1,24 +1,21 @@ package net.corda.serialization.internal.amqp import net.corda.core.serialization.SerializationContext -import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data -import java.io.NotSerializableException import java.lang.reflect.Type /** * Our definition of an enum with the AMQP spec is a list (of two items, a string and an int) that is * a restricted type with a number of choices associated with it */ -class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: SerializerFactory) : AMQPSerializer { +class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: LocalSerializerFactory) : AMQPSerializer { override val type: Type = declaredType private val typeNotation: TypeNotation - override val typeDescriptor = Symbol.valueOf( - "$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}")!! + override val typeDescriptor = factory.createDescriptor(type) init { typeNotation = RestrictedType( - SerializerFactory.nameForType(declaredType), + AMQPTypeIdentifiers.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor), declaredClass.enumConstants.zip(IntRange(0, declaredClass.enumConstants.size)).map { Choice(it.first.toString(), it.second.toString()) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializer.kt deleted file mode 100644 index 700a3b51bb..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializer.kt +++ /dev/null @@ -1,312 +0,0 @@ -package net.corda.serialization.internal.amqp - -import net.corda.core.KeepForDJVM -import net.corda.core.internal.isConcreteClass -import net.corda.core.serialization.DeprecatedConstructorForDeserialization -import net.corda.core.serialization.SerializationContext -import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.debug -import net.corda.core.utilities.loggerFor -import net.corda.serialization.internal.carpenter.getTypeAsClass -import org.apache.qpid.proton.codec.Data -import java.io.NotSerializableException -import java.lang.reflect.Type -import kotlin.reflect.KFunction -import kotlin.reflect.full.findAnnotation -import kotlin.reflect.jvm.javaType -import kotlin.reflect.jvm.jvmErasure - - -/** - * Serializer for deserializing objects whose definition has changed since they - * were serialised. - * - * @property oldReaders A linked map representing the properties of the object as they were serialized. Note - * this may contain properties that are no longer needed by the class. These *must* be read however to ensure - * any refferenced objects in the object stream are captured properly - * @property kotlinConstructor - * @property constructorArgs used to hold the properties as sent to the object's constructor. Passed in as a - * pre populated array as properties not present on the old constructor must be initialised in the factory - */ -abstract class EvolutionSerializer( - clazz: Type, - factory: SerializerFactory, - protected val oldReaders: Map, - override val kotlinConstructor: KFunction -) : ObjectSerializer(clazz, factory) { - // explicitly set as empty to indicate it's unused by this type of serializer - override val propertySerializers = PropertySerializersEvolution() - - /** - * Represents a parameter as would be passed to the constructor of the class as it was - * when it was serialised and NOT how that class appears now - * - * @param resultsIndex index into the constructor argument list where the read property - * should be placed - * @param property object to read the actual property value - */ - @KeepForDJVM - data class OldParam(var resultsIndex: Int, val property: PropertySerializer) { - fun readProperty(obj: Any?, schemas: SerializationSchemas, input: DeserializationInput, - new: Array, context: SerializationContext - ) = property.readProperty(obj, schemas, input, context).apply { - if (resultsIndex >= 0) { - new[resultsIndex] = this - } - } - - override fun toString(): String { - return "resultsIndex = $resultsIndex property = ${property.name}" - } - } - - companion object { - val logger = contextLogger() - - /** - * Unlike the generic deserialization case where we need to locate the primary constructor - * for the object (or our best guess) in the case of an object whose structure has changed - * since serialisation we need to attempt to locate a constructor that we can use. For example, - * its parameters match the serialised members and it will initialise any newly added - * elements. - * - * TODO: Type evolution - * TODO: rename annotation - */ - private fun getEvolverConstructor(type: Type, oldArgs: Map): KFunction? { - val clazz: Class<*> = type.asClass() - - if (!clazz.isConcreteClass) return null - - val oldArgumentSet = oldArgs.map { Pair(it.key as String?, it.value.property.resolvedType.asClass()) } - var maxConstructorVersion = Integer.MIN_VALUE - var constructor: KFunction? = null - - clazz.kotlin.constructors.forEach { - val version = it.findAnnotation()?.version ?: Integer.MIN_VALUE - - if (version > maxConstructorVersion && - oldArgumentSet.containsAll(it.parameters.map { v -> Pair(v.name, v.type.javaType.asClass()) }) - ) { - constructor = it - maxConstructorVersion = version - - with(logger) { - info("Select annotated constructor version=$version nparams=${it.parameters.size}") - debug{" params=${it.parameters}"} - } - } else if (version != Integer.MIN_VALUE){ - with(logger) { - info("Ignore annotated constructor version=$version nparams=${it.parameters.size}") - debug{" params=${it.parameters}"} - } - } - } - - // if we didn't get an exact match revert to existing behaviour, if the new parameters - // are not mandatory (i.e. nullable) things are fine - return constructor ?: run { - logger.info("Failed to find annotated historic constructor") - constructorForDeserialization(type) - } - } - - private fun makeWithConstructor( - new: ObjectSerializer, - factory: SerializerFactory, - constructor: KFunction, - readersAsSerialized: Map): AMQPSerializer { - - // Java doesn't care about nullability unless it's a primitive in which - // case it can't be referenced. Unfortunately whilst Kotlin does apply - // Nullability annotations we cannot use them here as they aren't - // retained at runtime so we cannot rely on the absence of - // any particular NonNullable annotation type to indicate cross - // compiler nullability - val isKotlin = (new.type.javaClass.declaredAnnotations.any { - it.annotationClass.qualifiedName == "kotlin.Metadata" - }) - - constructor.parameters.withIndex().forEach { - if ((readersAsSerialized[it.value.name!!] ?.apply { this.resultsIndex = it.index }) == null) { - // If there is no value in the byte stream to map to the parameter of the constructor - // this is ok IFF it's a Kotlin class and the parameter is non nullable OR - // its a Java class and the parameter is anything but an unboxed primitive. - // Otherwise we throw the error and leave - if ((isKotlin && !it.value.type.isMarkedNullable) - || (!isKotlin && isJavaPrimitive(it.value.type.jvmErasure.java)) - ) { - throw AMQPNotSerializableException( - new.type, - "New parameter \"${it.value.name}\" is mandatory, should be nullable for evolution " + - "to work, isKotlinClass=$isKotlin type=${it.value.type}") - } - } - } - return EvolutionSerializerViaConstructor(new.type, factory, readersAsSerialized, constructor) - } - - private fun makeWithSetters( - new: ObjectSerializer, - factory: SerializerFactory, - constructor: KFunction, - readersAsSerialized: Map, - classProperties: Map): AMQPSerializer { - val setters = propertiesForSerializationFromSetters(classProperties, - new.type, - factory).associateBy({ it.serializer.name }, { it }) - return EvolutionSerializerViaSetters(new.type, factory, readersAsSerialized, constructor, setters) - } - - /** - * Build a serialization object for deserialization only of objects serialised - * as different versions of a class. - * - * @param old is an object holding the schema that represents the object - * as it was serialised and the type descriptor of that type - * @param new is the Serializer built for the Class as it exists now, not - * how it was serialised and persisted. - * @param factory the [SerializerFactory] associated with the serialization - * context this serializer is being built for - */ - fun make(old: CompositeType, - new: ObjectSerializer, - factory: SerializerFactory - ): AMQPSerializer { - // The order in which the properties were serialised is important and must be preserved - val readersAsSerialized = LinkedHashMap() - old.fields.forEach { - readersAsSerialized[it.name] = try { - OldParam(-1, PropertySerializer.make(it.name, EvolutionPropertyReader(), - it.getTypeAsClass(factory.classloader), factory)) - } catch (e: ClassNotFoundException) { - throw AMQPNotSerializableException(new.type, e.message ?: "") - } - } - - // cope with the situation where a generic interface was serialised as a type, in such cases - // return the synthesised object which is, given the absence of a constructor, a no op - val constructor = getEvolverConstructor(new.type, readersAsSerialized) ?: return new - - val classProperties = new.type.asClass().propertyDescriptors() - - return if (classProperties.isNotEmpty() && constructor.parameters.isEmpty()) { - makeWithSetters(new, factory, constructor, readersAsSerialized, classProperties) - } else { - makeWithConstructor(new, factory, constructor, readersAsSerialized) - } - } - } - - override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, - context: SerializationContext, debugIndent: Int - ) { - throw UnsupportedOperationException("It should be impossible to write an evolution serializer") - } -} - -class EvolutionSerializerViaConstructor( - clazz: Type, - factory: SerializerFactory, - oldReaders: Map, - kotlinConstructor: KFunction) : EvolutionSerializer(clazz, factory, oldReaders, kotlinConstructor) { - /** - * Unlike a normal [readObject] call where we simply apply the parameter deserialisers - * to the object list of values we need to map that list, which is ordered per the - * constructor of the original state of the object, we need to map the new parameter order - * of the current constructor onto that list inserting nulls where new parameters are - * encountered. - * - * TODO: Object references - */ - override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, - context: SerializationContext - ): Any { - if (obj !is List<*>) throw NotSerializableException("Body of described type is unexpected $obj") - - val constructorArgs : Array = arrayOfNulls(kotlinConstructor.parameters.size) - // *must* read all the parameters in the order they were serialized - oldReaders.values.zip(obj).map { it.first.readProperty(it.second, schemas, input, constructorArgs, context) } - - return javaConstructor?.newInstance(*(constructorArgs)) ?: throw NotSerializableException( - "Attempt to deserialize an interface: $clazz. Serialized form is invalid.") - } -} - -/** - * Specific instance of an [EvolutionSerializer] where the properties of the object are set via calling - * named setter functions on the instantiated object. - */ -class EvolutionSerializerViaSetters( - clazz: Type, - factory: SerializerFactory, - oldReaders: Map, - kotlinConstructor: KFunction, - private val setters: Map) : EvolutionSerializer(clazz, factory, oldReaders, kotlinConstructor) { - - override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, - context: SerializationContext - ): Any { - if (obj !is List<*>) throw NotSerializableException("Body of described type is unexpected $obj") - - val instance: Any = javaConstructor?.newInstance() ?: throw NotSerializableException( - "Failed to instantiate instance of object $clazz") - - // *must* read all the parameters in the order they were serialized - oldReaders.values.zip(obj).forEach { - // if that property still exists on the new object then set it - it.first.property.readProperty(it.second, schemas, input, context).apply { - setters[it.first.property.name]?.set(instance, this) - } - } - return instance - } -} - -/** - * Instances of this type are injected into a [SerializerFactory] at creation time to dictate the - * behaviour of evolution within that factory. Under normal circumstances this will simply - * be an object that returns an [EvolutionSerializer]. Of course, any implementation that - * extends this class can be written to invoke whatever behaviour is desired. - */ -interface EvolutionSerializerProvider { - fun getEvolutionSerializer( - factory: SerializerFactory, - typeNotation: TypeNotation, - newSerializer: AMQPSerializer, - schemas: SerializationSchemas): AMQPSerializer -} - -/** - * The normal use case for generating an [EvolutionSerializer]'s based on the differences - * between the received schema and the class as it exists now on the class path, - */ -@KeepForDJVM -object DefaultEvolutionSerializerProvider : EvolutionSerializerProvider { - override fun getEvolutionSerializer(factory: SerializerFactory, - typeNotation: TypeNotation, - newSerializer: AMQPSerializer, - schemas: SerializationSchemas): AMQPSerializer { - return factory.registerByDescriptor(typeNotation.descriptor.name!!) { - when (typeNotation) { - is CompositeType -> EvolutionSerializer.make(typeNotation, newSerializer as ObjectSerializer, factory) - is RestrictedType -> { - // The fingerprint of a generic collection can be changed through bug fixes to the - // fingerprinting function making it appear as if the class has altered whereas it hasn't. - // Given we don't support the evolution of these generic containers, if it appears - // one has been changed, simply return the original serializer and associate it with - // both the new and old fingerprint - if (newSerializer is CollectionSerializer || newSerializer is MapSerializer) { - newSerializer - } else if (newSerializer is EnumSerializer){ - EnumEvolutionSerializer.make(typeNotation, newSerializer, factory, schemas) - } - else { - loggerFor().error("typeNotation=${typeNotation.name} Need to evolve unsupported type") - throw NotSerializableException ("${typeNotation.name} cannot be evolved") - } - } - } - } - } -} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt new file mode 100644 index 0000000000..49cb02dba4 --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactory.kt @@ -0,0 +1,170 @@ +package net.corda.serialization.internal.amqp + +import net.corda.serialization.internal.model.* +import java.io.NotSerializableException + +/** + * A factory that knows how to create serialisers when there is a mismatch between the remote and local type schemas. + */ +interface EvolutionSerializerFactory { + + /** + * Compare the given [RemoteTypeInformation] and [LocalTypeInformation], and construct (if needed) an evolution + * serialiser that can take properties serialised in the remote schema and construct an object conformant to the local schema. + * + * Will return null if no evolution is necessary, because the schemas are compatible. + */ + fun getEvolutionSerializer( + remote: RemoteTypeInformation, + local: LocalTypeInformation): AMQPSerializer? +} + +class EvolutionSerializationException(remoteTypeInformation: RemoteTypeInformation, reason: String) + : NotSerializableException( + """ + Cannot construct evolution serializer for remote type ${remoteTypeInformation.prettyPrint(false)} + + $reason + """.trimIndent() +) + +class DefaultEvolutionSerializerFactory( + private val localSerializerFactory: LocalSerializerFactory, + private val classLoader: ClassLoader, + private val mustPreserveDataWhenEvolving: Boolean): EvolutionSerializerFactory { + + override fun getEvolutionSerializer(remote: RemoteTypeInformation, + local: LocalTypeInformation): AMQPSerializer? = + when(remote) { + is RemoteTypeInformation.Composable -> + if (local is LocalTypeInformation.Composable) remote.getEvolutionSerializer(local) + else null + is RemoteTypeInformation.AnEnum -> + if (local is LocalTypeInformation.AnEnum) remote.getEvolutionSerializer(local) + else null + else -> null + } + + private fun RemoteTypeInformation.Composable.getEvolutionSerializer( + localTypeInformation: LocalTypeInformation.Composable): AMQPSerializer? { + // The no-op case: although the fingerprints don't match for some reason, we have compatible signatures. + // This might happen because of inconsistent type erasure, changes to the behaviour of the fingerprinter, + // or changes to the type itself - such as adding an interface - that do not change its serialisation/deserialisation + // signature. + if (propertyNamesMatch(localTypeInformation)) { + // Make sure types are assignment-compatible, and return the local serializer for the type. + validateCompatibility(localTypeInformation) + return null + } + + // Failing that, we have to create an evolution serializer. + val bestMatchEvolutionConstructor = findEvolverConstructor(localTypeInformation.evolutionConstructors, properties) + val constructorForEvolution = bestMatchEvolutionConstructor?.constructor ?: localTypeInformation.constructor + val evolverProperties = bestMatchEvolutionConstructor?.properties ?: localTypeInformation.properties + + validateEvolvability(evolverProperties) + + return buildComposableEvolutionSerializer(localTypeInformation, constructorForEvolution, evolverProperties) + } + + private fun RemoteTypeInformation.Composable.propertyNamesMatch(localTypeInformation: LocalTypeInformation.Composable): Boolean = + properties.keys == localTypeInformation.properties.keys + + private fun RemoteTypeInformation.Composable.validateCompatibility(localTypeInformation: LocalTypeInformation.Composable) { + properties.asSequence().zip(localTypeInformation.properties.values.asSequence()).forEach { (remote, localProperty) -> + val (name, remoteProperty) = remote + val localClass = localProperty.type.observedType.asClass() + val remoteClass = remoteProperty.type.typeIdentifier.getLocalType(classLoader).asClass() + + if (!localClass.isAssignableFrom(remoteClass)) { + throw EvolutionSerializationException(this, + "Local type $localClass of property $name is not assignable from remote type $remoteClass") + } + } + } + + // Find the evolution constructor with the highest version number whose parameters are all assignable from the + // provided property types. + private fun findEvolverConstructor(constructors: List, + properties: Map): EvolutionConstructorInformation? { + val propertyTypes = properties.mapValues { (_, info) -> info.type.typeIdentifier.getLocalType(classLoader).asClass() } + + // Evolver constructors are listed in ascending version order, so we just want the last that matches. + return constructors.lastOrNull { (_, evolverProperties) -> + // We have a match if all mandatory evolver properties have a type-compatible property in the remote type. + evolverProperties.all { (name, evolverProperty) -> + val propertyType = propertyTypes[name] + if (propertyType == null) !evolverProperty.isMandatory + else evolverProperty.type.observedType.asClass().isAssignableFrom(propertyType) + } + } + } + + private fun RemoteTypeInformation.Composable.validateEvolvability(localProperties: Map) { + val remotePropertyNames = properties.keys + val localPropertyNames = localProperties.keys + val deletedProperties = remotePropertyNames - localPropertyNames + val newProperties = localPropertyNames - remotePropertyNames + + // Here is where we can exercise a veto on evolutions that remove properties. + if (deletedProperties.isNotEmpty() && mustPreserveDataWhenEvolving) + throw EvolutionSerializationException(this, + "Property ${deletedProperties.first()} of remote ContractState type is not present in local type, " + + "and context is configured to prevent forwards-compatible deserialization.") + + // Check mandatory-ness of constructor-set properties. + newProperties.forEach { propertyName -> + if (localProperties[propertyName]!!.mustBeProvided) throw EvolutionSerializationException( + this, + "Mandatory property $propertyName of local type is not present in remote type - " + + "did someone remove a property from the schema without considering old clients?") + } + } + + private val LocalPropertyInformation.mustBeProvided: Boolean get() = when(this) { + is LocalPropertyInformation.ConstructorPairedProperty -> isMandatory + is LocalPropertyInformation.PrivateConstructorPairedProperty -> isMandatory + else -> false + } + + private fun RemoteTypeInformation.AnEnum.getEvolutionSerializer( + localTypeInformation: LocalTypeInformation.AnEnum): AMQPSerializer? { + if (members == localTypeInformation.members) return null + + val remoteTransforms = transforms + val localTransforms = localTypeInformation.getEnumTransforms(localSerializerFactory) + val transforms = if (remoteTransforms.size > localTransforms.size) remoteTransforms else localTransforms + + val localOrdinals = localTypeInformation.members.asSequence().mapIndexed { ord, member -> member to ord }.toMap() + val remoteOrdinals = members.asSequence().mapIndexed { ord, member -> member to ord }.toMap() + val rules = transforms.defaults + transforms.renames + + // We just trust our transformation rules not to contain cycles here. + tailrec fun findLocal(remote: String): String = + if (remote in localOrdinals) remote + else findLocal(rules[remote] ?: throw EvolutionSerializationException( + this, + "Cannot resolve local enum member $remote to a member of ${localOrdinals.keys} using rules $rules" + )) + + val conversions = members.associate { it to findLocal(it) } + val convertedOrdinals = remoteOrdinals.asSequence().map { (member, ord) -> ord to conversions[member]!! }.toMap() + if (localOrdinals.any { (name, ordinal) -> convertedOrdinals[ordinal] != name }) + throw EvolutionSerializationException( + this, + "Constants have been reordered, additions must be appended to the end") + + return EnumEvolutionSerializer(localTypeInformation.observedType, localSerializerFactory, conversions, localOrdinals) + } + + private fun RemoteTypeInformation.Composable.buildComposableEvolutionSerializer( + localTypeInformation: LocalTypeInformation.Composable, + constructor: LocalConstructorInformation, + properties: Map): AMQPSerializer = + EvolutionObjectSerializer.make( + localTypeInformation, + this, + constructor, + properties, + classLoader) +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/FingerPrinter.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/FingerPrinter.kt deleted file mode 100644 index d03e3e6da1..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/FingerPrinter.kt +++ /dev/null @@ -1,202 +0,0 @@ -package net.corda.serialization.internal.amqp - -import com.google.common.hash.Hasher -import com.google.common.hash.Hashing -import net.corda.core.KeepForDJVM -import net.corda.core.internal.isConcreteClass -import net.corda.core.internal.kotlinObjectInstance -import net.corda.core.utilities.toBase64 -import net.corda.serialization.internal.amqp.SerializerFactory.Companion.isPrimitive -import java.lang.reflect.* -import java.util.* - -/** - * Should be implemented by classes which wish to provide pluggable fingerprinting on types for a [SerializerFactory] - */ -@KeepForDJVM -interface FingerPrinter { - /** - * Return a unique identifier for a type, usually this will take into account the constituent elements - * of said type such that any modification to any sub element wll generate a different fingerprint - */ - fun fingerprint(type: Type): String -} - -/** - * Implementation of the finger printing mechanism used by default - */ -@KeepForDJVM -class SerializerFingerPrinter(val factory: SerializerFactory) : FingerPrinter { - - /** - * The method generates a fingerprint for a given JVM [Type] that should be unique to the schema representation. - * Thus it only takes into account properties and types and only supports the same object graph subset as the overall - * serialization code. - * - * The idea being that even for two classes that share the same name but differ in a minor way, the fingerprint will be - * different. - */ - override fun fingerprint(type: Type): String = FingerPrintingState(factory).fingerprint(type) -} - -// Representation of the current state of fingerprinting -internal class FingerPrintingState(private val factory: SerializerFactory) { - - companion object { - private const val ARRAY_HASH: String = "Array = true" - private const val ENUM_HASH: String = "Enum = true" - private const val ALREADY_SEEN_HASH: String = "Already seen = true" - private const val NULLABLE_HASH: String = "Nullable = true" - private const val NOT_NULLABLE_HASH: String = "Nullable = false" - private const val ANY_TYPE_HASH: String = "Any type = true" - } - - private val typesSeen: MutableSet = mutableSetOf() - private var currentContext: Type? = null - private var hasher: Hasher = newDefaultHasher() - - // Fingerprint the type recursively, and return the encoded fingerprint written into the hasher. - fun fingerprint(type: Type) = fingerprintType(type).hasher.fingerprint - - // This method concatenates various elements of the types recursively as unencoded strings into the hasher, - // effectively creating a unique string for a type which we then hash in the calling function above. - private fun fingerprintType(type: Type): FingerPrintingState = apply { - // Don't go round in circles. - if (hasSeen(type)) append(ALREADY_SEEN_HASH) - else ifThrowsAppend( - { type.typeName }, - { - typesSeen.add(type) - currentContext = type - fingerprintNewType(type) - }) - } - - // For a type we haven't seen before, determine the correct path depending on the type of type it is. - private fun fingerprintNewType(type: Type) = when (type) { - is ParameterizedType -> fingerprintParameterizedType(type) - // Previously, we drew a distinction between TypeVariable, WildcardType, and AnyType, changing - // the signature of the fingerprinted object. This, however, doesn't work as it breaks bi- - // directional fingerprints. That is, fingerprinting a concrete instance of a generic - // type (Example), creates a different fingerprint from the generic type itself (Example) - // - // On serialization Example is treated as Example, a TypeVariable - // On deserialisation it is seen as Example, A WildcardType *and* a TypeVariable - // Note: AnyType is a special case of WildcardType used in other parts of the - // serializer so both cases need to be dealt with here - // - // If we treat these types as fundamentally different and alter the fingerprint we will - // end up breaking into the evolver when we shouldn't or, worse, evoking the carpenter. - is SerializerFactory.AnyType, - is WildcardType, - is TypeVariable<*> -> append("?$ANY_TYPE_HASH") - is Class<*> -> fingerprintClass(type) - is GenericArrayType -> fingerprintType(type.genericComponentType).append(ARRAY_HASH) - else -> throw AMQPNotSerializableException(type, "Don't know how to hash") - } - - private fun fingerprintClass(type: Class<*>) = when { - type.isArray -> fingerprintType(type.componentType).append(ARRAY_HASH) - type.isPrimitiveOrCollection -> append(type.name) - type.isEnum -> fingerprintEnum(type) - else -> fingerprintWithCustomSerializerOrElse(type, type) { - if (type.kotlinObjectInstance != null) append(type.name) - else fingerprintObject(type) - } - } - - private fun fingerprintParameterizedType(type: ParameterizedType) { - // Hash the rawType + params - type.asClass().let { clazz -> - if (clazz.isCollectionOrMap) append(clazz.name) - else fingerprintWithCustomSerializerOrElse(clazz, type) { - fingerprintObject(type) - } - } - - // ...and concatenate the type data for each parameter type. - type.actualTypeArguments.forEach { paramType -> - fingerprintType(paramType) - } - } - - private fun fingerprintObject(type: Type) { - // Hash the class + properties + interfaces - append(type.asClass().name) - - orderedPropertiesForSerialization(type).forEach { prop -> - fingerprintType(prop.serializer.resolvedType) - fingerprintPropSerialiser(prop) - } - - interfacesForSerialization(type, factory).forEach { iface -> - fingerprintType(iface) - } - } - - // ensures any change to the enum (adding constants) will trigger the need for evolution - private fun fingerprintEnum(type: Class<*>) { - append(type.enumConstants.joinToString()) - append(type.name) - append(ENUM_HASH) - } - - private fun fingerprintPropSerialiser(prop: PropertyAccessor) { - append(prop.serializer.name) - append(if (prop.serializer.mandatory) NOT_NULLABLE_HASH - else NULLABLE_HASH) - } - - // Write the given character sequence into the hasher. - private fun append(chars: CharSequence) { - hasher = hasher.putUnencodedChars(chars) - } - - // Give any custom serializers loaded into the factory the chance to supply their own type-descriptors - private fun fingerprintWithCustomSerializerOrElse( - clazz: Class<*>, - declaredType: Type, - defaultAction: () -> Unit) - : Unit = factory.findCustomSerializer(clazz, declaredType)?.let { - append(it.typeDescriptor) - } ?: defaultAction() - - // Test whether we are in a state in which we have already seen the given type. - // - // We don't include Example and Example where type is ? or T in this otherwise we - // generate different fingerprints for class Outer(val a: Inner) when serialising - // and deserializing (assuming deserialization is occurring in a factory that didn't - // serialise the object in the first place (and thus the cache lookup fails). This is also - // true of Any, where we need Example and Example to have the same fingerprint - private fun hasSeen(type: Type) = (type in typesSeen) - && (type !== SerializerFactory.AnyType) - && (type !is TypeVariable<*>) - && (type !is WildcardType) - - private fun orderedPropertiesForSerialization(type: Type): List { - return propertiesForSerialization( - if (type.asClass().isConcreteClass) constructorForDeserialization(type) else null, - currentContext ?: type, - factory).serializationOrder - } - -} - -// region Utility functions - -// Create a new instance of the [Hasher] used for fingerprinting by the default [SerializerFingerPrinter] -private fun newDefaultHasher() = Hashing.murmur3_128().newHasher() - -// We obtain a fingerprint from a [Hasher] by taking the Base 64 encoding of its hash bytes -private val Hasher.fingerprint get() = hash().asBytes().toBase64() - -internal fun fingerprintForDescriptors(vararg typeDescriptors: String): String = - newDefaultHasher().putUnencodedChars(typeDescriptors.joinToString()).fingerprint - -private val Class<*>.isCollectionOrMap get() = - (Collection::class.java.isAssignableFrom(this) || Map::class.java.isAssignableFrom(this)) - && !EnumSet::class.java.isAssignableFrom(this) - -private val Class<*>.isPrimitiveOrCollection get() = - isPrimitive(this) || isCollectionOrMap -// endregion diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt new file mode 100644 index 0000000000..9b1b530764 --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt @@ -0,0 +1,226 @@ +package net.corda.serialization.internal.amqp + +import net.corda.core.internal.kotlinObjectInstance +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.core.utilities.trace +import net.corda.serialization.internal.model.* +import org.apache.qpid.proton.amqp.Symbol +import java.io.NotSerializableException +import java.lang.reflect.ParameterizedType +import java.lang.reflect.Type +import java.lang.reflect.WildcardType +import java.util.* +import javax.annotation.concurrent.ThreadSafe + +/** + * A factory that handles the serialisation and deserialisation of [Type]s visible from a given [ClassLoader]. + * + * Unlike the [RemoteSerializerFactory], which deals with types for which we have [Schema] information and serialised data, + * the [LocalSerializerFactory] deals with types for which we have a Java [Type] (and perhaps some in-memory data, from which + * we can discover the actual [Class] we are working with. + */ +interface LocalSerializerFactory { + /** + * The [ClassWhitelist] used by this factory. Classes must be whitelisted for serialization, because they are expected + * to be written in a secure manner. + */ + val whitelist: ClassWhitelist + + /** + * The [ClassLoader] used by this factory. + */ + val classloader: ClassLoader + + /** + * Obtain an [AMQPSerializer] for an object of actual type [actualClass], and declared type [declaredType]. + */ + fun get(actualClass: Class<*>, declaredType: Type): AMQPSerializer + + /** + * Obtain an [AMQPSerializer] for the [declaredType]. + */ + fun get(declaredType: Type): AMQPSerializer = get(getTypeInformation(declaredType)) + + /** + * Obtain an [AMQPSerializer] for the type having the given [typeInformation]. + */ + fun get(typeInformation: LocalTypeInformation): AMQPSerializer + + /** + * Obtain [LocalTypeInformation] for the given [Type]. + */ + fun getTypeInformation(type: Type): LocalTypeInformation + + /** + * Use the [FingerPrinter] to create a type descriptor for the given [type]. + */ + fun createDescriptor(type: Type): Symbol = createDescriptor(getTypeInformation(type)) + + /** + * Use the [FingerPrinter] to create a type descriptor for the given [typeInformation]. + */ + fun createDescriptor(typeInformation: LocalTypeInformation): Symbol + + /** + * Obtain or register [Transform]s for the given class [name]. + * + * Eventually this information should be moved into the [LocalTypeInformation] for the type. + */ + fun getOrBuildTransform(name: String, builder: () -> EnumMap>): + EnumMap> +} + +/** + * A [LocalSerializerFactory] equipped with a [LocalTypeModel] and a [FingerPrinter] to help it build fingerprint-based descriptors + * and serializers for local types. + */ +@ThreadSafe +class DefaultLocalSerializerFactory( + override val whitelist: ClassWhitelist, + private val typeModel: LocalTypeModel, + private val fingerPrinter: FingerPrinter, + override val classloader: ClassLoader, + private val descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry, + private val customSerializerRegistry: CustomSerializerRegistry, + private val onlyCustomSerializers: Boolean) + : LocalSerializerFactory { + + companion object { + val logger = contextLogger() + } + + private val transformsCache: MutableMap>> = DefaultCacheProvider.createCache() + private val serializersByType: MutableMap> = DefaultCacheProvider.createCache() + + override fun createDescriptor(typeInformation: LocalTypeInformation): Symbol = + Symbol.valueOf("$DESCRIPTOR_DOMAIN:${fingerPrinter.fingerprint(typeInformation)}") + + override fun getTypeInformation(type: Type): LocalTypeInformation = typeModel.inspect(type) + + override fun getOrBuildTransform(name: String, builder: () -> EnumMap>): + EnumMap> = + transformsCache.computeIfAbsent(name) { _ -> builder() } + + override fun get(typeInformation: LocalTypeInformation): AMQPSerializer = + get(typeInformation.observedType, typeInformation) + + private fun make(typeInformation: LocalTypeInformation, build: () -> AMQPSerializer) = + make(typeInformation.typeIdentifier, build) + + private fun make(typeIdentifier: TypeIdentifier, build: () -> AMQPSerializer) = + serializersByType.computeIfAbsent(typeIdentifier) { _ -> build() } + + private fun get(declaredType: Type, localTypeInformation: LocalTypeInformation): AMQPSerializer { + val declaredClass = declaredType.asClass() + + // can be useful to enable but will be *extremely* chatty if you do + logger.trace { "Get Serializer for $declaredClass ${declaredType.typeName}" } + + return when(localTypeInformation) { + is LocalTypeInformation.ACollection -> makeDeclaredCollection(localTypeInformation) + is LocalTypeInformation.AMap -> makeDeclaredMap(localTypeInformation) + is LocalTypeInformation.AnEnum -> makeDeclaredEnum(localTypeInformation, declaredType, declaredClass) + else -> makeClassSerializer(declaredClass, declaredType, declaredType, localTypeInformation) + }.also { serializer -> descriptorBasedSerializerRegistry[serializer.typeDescriptor.toString()] = serializer } + } + + private fun makeDeclaredEnum(localTypeInformation: LocalTypeInformation, declaredType: Type, declaredClass: Class<*>): AMQPSerializer = + make(localTypeInformation) { + whitelist.requireWhitelisted(declaredType) + EnumSerializer(declaredType, declaredClass, this) + } + + private fun makeActualEnum(localTypeInformation: LocalTypeInformation, declaredType: Type, declaredClass: Class<*>): AMQPSerializer = + make(localTypeInformation) { + whitelist.requireWhitelisted(declaredType) + EnumSerializer(declaredType, declaredClass, this) + } + + private fun makeDeclaredCollection(localTypeInformation: LocalTypeInformation.ACollection): AMQPSerializer { + val resolved = CollectionSerializer.resolveDeclared(localTypeInformation) + return make(resolved) { + CollectionSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) + } + } + + private fun makeDeclaredMap(localTypeInformation: LocalTypeInformation.AMap): AMQPSerializer { + val resolved = MapSerializer.resolveDeclared(localTypeInformation) + return make(resolved) { + MapSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) + } + } + + override fun get(actualClass: Class<*>, declaredType: Type): AMQPSerializer { + // can be useful to enable but will be *extremely* chatty if you do + logger.trace { "Get Serializer for $actualClass ${declaredType.typeName}" } + + val declaredClass = declaredType.asClass() + val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType + val declaredTypeInformation = typeModel.inspect(declaredType) + val actualTypeInformation = typeModel.inspect(actualType) + + return when(actualTypeInformation) { + is LocalTypeInformation.ACollection -> makeActualCollection(actualClass,declaredTypeInformation as? LocalTypeInformation.ACollection ?: actualTypeInformation) + is LocalTypeInformation.AMap -> makeActualMap(declaredType, actualClass,declaredTypeInformation as? LocalTypeInformation.AMap ?: actualTypeInformation) + is LocalTypeInformation.AnEnum -> makeActualEnum(actualTypeInformation, actualType, actualClass) + else -> makeClassSerializer(actualClass, actualType, declaredType, actualTypeInformation) + }.also { serializer -> descriptorBasedSerializerRegistry[serializer.typeDescriptor.toString()] = serializer } + } + + private fun makeActualMap(declaredType: Type, actualClass: Class<*>, typeInformation: LocalTypeInformation.AMap): AMQPSerializer { + declaredType.asClass().checkSupportedMapType() + val resolved = MapSerializer.resolveActual(actualClass, typeInformation) + return make(resolved) { + MapSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) + } + } + + private fun makeActualCollection(actualClass: Class<*>, typeInformation: LocalTypeInformation.ACollection): AMQPSerializer { + val resolved = CollectionSerializer.resolveActual(actualClass, typeInformation) + + return serializersByType.computeIfAbsent(resolved.typeIdentifier) { + CollectionSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) + } + } + + private fun makeClassSerializer( + clazz: Class<*>, + type: Type, + declaredType: Type, + typeInformation: LocalTypeInformation + ): AMQPSerializer = make(typeInformation) { + logger.debug { "class=${clazz.simpleName}, type=$type is a composite type" } + when { + clazz.isSynthetic -> // Explicitly ban synthetic classes, we have no way of recreating them when deserializing. This also + // captures Lambda expressions and other anonymous functions + throw AMQPNotSerializableException( + type, + "Serializer does not support synthetic classes") + AMQPTypeIdentifiers.isPrimitive(typeInformation.typeIdentifier) -> AMQPPrimitiveSerializer(clazz) + else -> customSerializerRegistry.findCustomSerializer(clazz, declaredType) ?: + makeNonCustomSerializer(type, typeInformation, clazz) + } + } + + private fun makeNonCustomSerializer(type: Type, typeInformation: LocalTypeInformation, clazz: Class<*>): AMQPSerializer = when { + onlyCustomSerializers -> throw AMQPNotSerializableException(type, "Only allowing custom serializers") + type.isArray() -> + if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) + else { + ArraySerializer.make(type, this) + } + else -> { + val singleton = clazz.kotlinObjectInstance + if (singleton != null) { + whitelist.requireWhitelisted(clazz) + SingletonSerializer(clazz, singleton, this) + } else { + whitelist.requireWhitelisted(type) + ObjectSerializer.make(typeInformation, this) + } + } + } + +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/MapSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/MapSerializer.kt index 742c6a84ed..2e00e8d206 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/MapSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/MapSerializer.kt @@ -4,6 +4,8 @@ import net.corda.core.KeepForDJVM import net.corda.core.StubOutForDJVM import net.corda.core.internal.uncheckedCast import net.corda.core.serialization.SerializationContext +import net.corda.serialization.internal.model.LocalTypeInformation +import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException @@ -18,11 +20,10 @@ private typealias MapCreationFunction = (Map<*, *>) -> Map<*, *> * Serialization / deserialization of certain supported [Map] types. */ @KeepForDJVM -class MapSerializer(private val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer { - override val type: Type = (declaredType as? DeserializedParameterizedType) - ?: DeserializedParameterizedType.make(SerializerFactory.nameForType(declaredType), factory.classloader) - override val typeDescriptor: Symbol = Symbol.valueOf( - "$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}") +class MapSerializer(private val declaredType: ParameterizedType, factory: LocalSerializerFactory) : AMQPSerializer { + override val type: Type = declaredType + + override val typeDescriptor: Symbol = factory.createDescriptor(type) companion object { // NB: Order matters in this map, the most specific classes should be listed at the end @@ -39,29 +40,43 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial } )) + private val supportedTypeIdentifiers = supportedTypes.keys.asSequence() + .map { TypeIdentifier.forGenericType(it) }.toSet() + private fun findConcreteType(clazz: Class<*>): MapCreationFunction { return supportedTypes[clazz] ?: throw AMQPNotSerializableException(clazz, "Unsupported map type $clazz.") } - fun deriveParameterizedType(declaredType: Type, declaredClass: Class<*>, actualClass: Class<*>?): ParameterizedType { - declaredClass.checkSupportedMapType() - if (supportedTypes.containsKey(declaredClass)) { - // Simple case - it is already known to be a map. - return deriveParametrizedType(declaredType, uncheckedCast(declaredClass)) - } else if (actualClass != null && Map::class.java.isAssignableFrom(actualClass)) { - // Declared class is not map, but [actualClass] is - represent it accordingly. - val mapClass = findMostSuitableMapType(actualClass) - return deriveParametrizedType(declaredType, mapClass) - } + fun resolveDeclared(declaredTypeInformation: LocalTypeInformation.AMap): LocalTypeInformation.AMap { + declaredTypeInformation.observedType.asClass().checkSupportedMapType() + if (supportedTypeIdentifiers.contains(declaredTypeInformation.typeIdentifier.erased)) + return if (!declaredTypeInformation.isErased) declaredTypeInformation + else declaredTypeInformation.withParameters(LocalTypeInformation.Unknown, LocalTypeInformation.Unknown) - throw AMQPNotSerializableException(declaredType, - "Cannot derive map type for declaredType=\"$declaredType\", declaredClass=\"$declaredClass\", actualClass=\"$actualClass\"") + throw NotSerializableException("Cannot derive map type for declared type " + + declaredTypeInformation.prettyPrint(false)) } - private fun deriveParametrizedType(declaredType: Type, collectionClass: Class>): ParameterizedType = - (declaredType as? ParameterizedType) - ?: DeserializedParameterizedType(collectionClass, arrayOf(SerializerFactory.AnyType, SerializerFactory.AnyType)) + fun resolveActual(actualClass: Class<*>, declaredTypeInformation: LocalTypeInformation.AMap): LocalTypeInformation.AMap { + declaredTypeInformation.observedType.asClass().checkSupportedMapType() + if (supportedTypeIdentifiers.contains(declaredTypeInformation.typeIdentifier.erased)) { + return if (!declaredTypeInformation.isErased) declaredTypeInformation + else declaredTypeInformation.withParameters(LocalTypeInformation.Unknown, LocalTypeInformation.Unknown) + } + val mapClass = findMostSuitableMapType(actualClass) + val erasedInformation = LocalTypeInformation.AMap( + mapClass, + TypeIdentifier.forClass(mapClass), + LocalTypeInformation.Unknown, LocalTypeInformation.Unknown) + + return when(declaredTypeInformation.typeIdentifier) { + is TypeIdentifier.Parameterised -> erasedInformation.withParameters( + declaredTypeInformation.keyType, + declaredTypeInformation.valueType) + else -> erasedInformation.withParameters(LocalTypeInformation.Unknown, LocalTypeInformation.Unknown) + } + } private fun findMostSuitableMapType(actualClass: Class<*>): Class> = MapSerializer.supportedTypes.keys.findLast { it.isAssignableFrom(actualClass) }!! @@ -69,7 +84,7 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial private val concreteBuilder: MapCreationFunction = findConcreteType(declaredType.rawType as Class<*>) - private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "map", Descriptor(typeDescriptor), emptyList()) + private val typeNotation: TypeNotation = RestrictedType(AMQPTypeIdentifiers.nameForType(declaredType), null, emptyList(), "map", Descriptor(typeDescriptor), emptyList()) private val inboundKeyType = declaredType.actualTypeArguments[0] private val outboundKeyType = resolveTypeVariables(inboundKeyType, null) @@ -108,7 +123,6 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext ): Any = ifThrowsAppend({ declaredType.typeName }) { - // TODO: General generics question. Do we need to validate that entries in Maps and Collections match the generic type? Is it a security hole? val entries: Iterable> = (obj as Map<*, *>).map { readEntry(schemas, input, it, context) } concreteBuilder(entries.toMap()) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectBuilder.kt new file mode 100644 index 0000000000..2b32e5482b --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectBuilder.kt @@ -0,0 +1,112 @@ +package net.corda.serialization.internal.amqp + +import net.corda.serialization.internal.model.* +import java.io.NotSerializableException + +interface ObjectBuilder { + + companion object { + fun makeProvider(typeInformation: LocalTypeInformation.Composable): () -> ObjectBuilder = + makeProvider(typeInformation.typeIdentifier, typeInformation.constructor, typeInformation.properties) + + fun makeProvider(typeIdentifier: TypeIdentifier, constructor: LocalConstructorInformation, properties: Map): () -> ObjectBuilder { + val nonCalculatedProperties = properties.asSequence() + .filterNot { (name, property) -> property.isCalculated } + .sortedBy { (name, _) -> name } + .map { (_, property) -> property } + .toList() + + val propertyIndices = nonCalculatedProperties.mapNotNull { + when(it) { + is LocalPropertyInformation.ConstructorPairedProperty -> it.constructorSlot.parameterIndex + is LocalPropertyInformation.PrivateConstructorPairedProperty -> it.constructorSlot.parameterIndex + else -> null + } + }.toIntArray() + + if (propertyIndices.isNotEmpty()) { + if (propertyIndices.size != nonCalculatedProperties.size) { + throw NotSerializableException( + "Some but not all properties of ${typeIdentifier.prettyPrint(false)} " + + "are constructor-based") + } + return { ConstructorBasedObjectBuilder(constructor, propertyIndices) } + } + + val getterSetter = nonCalculatedProperties.filterIsInstance() + return { SetterBasedObjectBuilder(constructor, getterSetter) } + } + } + + fun initialize() + fun populate(slot: Int, value: Any?) + fun build(): Any +} + +class SetterBasedObjectBuilder( + val constructor: LocalConstructorInformation, + val properties: List): ObjectBuilder { + + private lateinit var target: Any + + override fun initialize() { + target = constructor.observedMethod.call() + } + + override fun populate(slot: Int, value: Any?) { + properties[slot].observedSetter.invoke(target, value) + } + + override fun build(): Any = target +} + +class ConstructorBasedObjectBuilder( + val constructor: LocalConstructorInformation, + val parameterIndices: IntArray): ObjectBuilder { + + private val params = arrayOfNulls(parameterIndices.size) + + override fun initialize() {} + + override fun populate(slot: Int, value: Any?) { + if (slot >= parameterIndices.size) { + assert(false) + } + val parameterIndex = parameterIndices[slot] + if (parameterIndex >= params.size) { + assert(false) + } + params[parameterIndex] = value + } + + override fun build(): Any = constructor.observedMethod.call(*params) +} + +class EvolutionObjectBuilder(private val localBuilder: ObjectBuilder, val slotAssignments: IntArray): ObjectBuilder { + + companion object { + fun makeProvider(typeIdentifier: TypeIdentifier, constructor: LocalConstructorInformation, localProperties: Map, providedProperties: List): () -> ObjectBuilder { + val localBuilderProvider = ObjectBuilder.makeProvider(typeIdentifier, constructor, localProperties) + val localPropertyIndices = localProperties.asSequence() + .filter { (_, property) -> !property.isCalculated } + .mapIndexed { slot, (name, _) -> name to slot } + .toMap() + + val reroutedIndices = providedProperties.map { propertyName -> localPropertyIndices[propertyName] ?: -1 } + .toIntArray() + + return { EvolutionObjectBuilder(localBuilderProvider(), reroutedIndices) } + } + } + + override fun initialize() { + localBuilder.initialize() + } + + override fun populate(slot: Int, value: Any?) { + val slotAssignment = slotAssignments[slot] + if (slotAssignment != -1) localBuilder.populate(slotAssignment, value) + } + + override fun build(): Any = localBuilder.build() +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt index 9ae529b608..5d422e9156 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt @@ -1,185 +1,204 @@ package net.corda.serialization.internal.amqp -import net.corda.core.internal.isConcreteClass import net.corda.core.serialization.SerializationContext -import net.corda.core.serialization.serialize -import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.trace -import net.corda.serialization.internal.amqp.SerializerFactory.Companion.nameForType +import net.corda.serialization.internal.model.* import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException -import java.lang.reflect.Constructor -import java.lang.reflect.InvocationTargetException import java.lang.reflect.Type -import kotlin.reflect.jvm.javaConstructor -/** - * Responsible for serializing and deserializing a regular object instance via a series of properties - * (matched with a constructor). - */ -open class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPSerializer { - override val type: Type get() = clazz - open val kotlinConstructor = if (clazz.asClass().isConcreteClass) constructorForDeserialization(clazz) else null - val javaConstructor by lazy { kotlinConstructor?.javaConstructor } +interface ObjectSerializer : AMQPSerializer { + + val propertySerializers: Map + val fields: List companion object { - private val logger = contextLogger() + fun make(typeInformation: LocalTypeInformation, factory: LocalSerializerFactory): ObjectSerializer { + val typeDescriptor = factory.createDescriptor(typeInformation) + val typeNotation = TypeNotationGenerator.getTypeNotation(typeInformation, typeDescriptor) + + return when (typeInformation) { + is LocalTypeInformation.Composable -> + makeForComposable(typeInformation, typeNotation, typeDescriptor, factory) + is LocalTypeInformation.AnInterface, + is LocalTypeInformation.Abstract -> + makeForAbstract(typeNotation, typeInformation, typeDescriptor, factory) + else -> throw NotSerializableException("Cannot build object serializer for $typeInformation") + } + } + + private fun makeForAbstract(typeNotation: CompositeType, + typeInformation: LocalTypeInformation, + typeDescriptor: Symbol, + factory: LocalSerializerFactory): AbstractObjectSerializer { + val propertySerializers = makePropertySerializers(typeInformation.propertiesOrEmptyMap, factory) + val writer = ComposableObjectWriter(typeNotation, typeInformation.interfacesOrEmptyList, propertySerializers) + return AbstractObjectSerializer(typeInformation.observedType, typeDescriptor, propertySerializers, + typeNotation.fields, writer) + } + + private fun makeForComposable(typeInformation: LocalTypeInformation.Composable, + typeNotation: CompositeType, + typeDescriptor: Symbol, + factory: LocalSerializerFactory): ComposableObjectSerializer { + val propertySerializers = makePropertySerializers(typeInformation.properties, factory) + val reader = ComposableObjectReader( + typeInformation.typeIdentifier, + propertySerializers, + ObjectBuilder.makeProvider(typeInformation)) + + val writer = ComposableObjectWriter( + typeNotation, + typeInformation.interfaces, + propertySerializers) + + return ComposableObjectSerializer( + typeInformation.observedType, + typeDescriptor, + propertySerializers, + typeNotation.fields, + reader, + writer) + } + + private fun makePropertySerializers(properties: Map, + factory: LocalSerializerFactory): Map = + properties.mapValues { (name, property) -> + ComposableTypePropertySerializer.make(name, property, factory) + } } +} - open val propertySerializers: PropertySerializers by lazy { - propertiesForSerialization(kotlinConstructor, clazz, factory) - } +class ComposableObjectSerializer( + override val type: Type, + override val typeDescriptor: Symbol, + override val propertySerializers: Map, + override val fields: List, + private val reader: ComposableObjectReader, + private val writer: ComposableObjectWriter): ObjectSerializer { - private val typeName = nameForType(clazz) + override fun writeClassInfo(output: SerializationOutput) = writer.writeClassInfo(output) - override val typeDescriptor: Symbol = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}") + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int) = + writer.writeObject(obj, data, type, output, context, debugIndent) - // We restrict to only those annotated or whitelisted - private val interfaces = interfacesForSerialization(clazz, factory) + override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any = + reader.readObject(obj, schemas, input, context) +} - internal open val typeNotation: TypeNotation by lazy { - CompositeType(typeName, null, generateProvides(), Descriptor(typeDescriptor), generateFields()) - } - - override fun writeClassInfo(output: SerializationOutput) { +class ComposableObjectWriter( + private val typeNotation: TypeNotation, + private val interfaces: List, + private val propertySerializers: Map +) { + fun writeClassInfo(output: SerializationOutput) { if (output.writeTypeNotations(typeNotation)) { for (iface in interfaces) { - output.requireSerializer(iface) + output.requireSerializer(iface.observedType) } - propertySerializers.serializationOrder.forEach { property -> - property.serializer.writeClassInfo(output) + propertySerializers.values.forEach { serializer -> + serializer.writeClassInfo(output) } } } - override fun writeObject( - obj: Any, - data: Data, - type: Type, - output: SerializationOutput, - context: SerializationContext, - debugIndent: Int) = ifThrowsAppend({ clazz.typeName } - ) { - if (propertySerializers.deserializableSize != javaConstructor?.parameterCount && - javaConstructor?.parameterCount ?: 0 > 0 - ) { - throw AMQPNotSerializableException(type, "Serialization constructor for class $type expects " - + "${javaConstructor?.parameterCount} parameters but we have ${propertySerializers.size} " - + "properties to serialize.") - } - - // Write described + fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int) { data.withDescribed(typeNotation.descriptor) { - // Write list withList { - propertySerializers.serializationOrder.forEach { property -> - property.serializer.writeProperty(obj, this, output, context, debugIndent + 1) + propertySerializers.values.forEach { propertySerializer -> + propertySerializer.writeProperty(obj, this, output, context, debugIndent + 1) } } } } +} - override fun readObject( - obj: Any, - schemas: SerializationSchemas, - input: DeserializationInput, - context: SerializationContext): Any = ifThrowsAppend({ clazz.typeName }) { - if (obj is List<*>) { - if (obj.size != propertySerializers.size) { - throw AMQPNotSerializableException(type, "${obj.size} objects to deserialize, but " + - "${propertySerializers.size} properties in described type $typeName") - } +class ComposableObjectReader( + val typeIdentifier: TypeIdentifier, + private val propertySerializers: Map, + private val objectBuilderProvider: () -> ObjectBuilder +) { - return if (propertySerializers.byConstructor) { - readObjectBuildViaConstructor(obj, schemas, input, context) - } else { - readObjectBuildViaSetters(obj, schemas, input, context) - } - } else { - throw AMQPNotSerializableException(type, "Body of described type is unexpected $obj") - } - } - - private fun readObjectBuildViaConstructor( - obj: List<*>, - schemas: SerializationSchemas, - input: DeserializationInput, - context: SerializationContext): Any = ifThrowsAppend({ clazz.typeName }) { - logger.trace { "Calling construction based construction for ${clazz.typeName}" } - - return construct(propertySerializers.serializationOrder - .zip(obj) - .mapNotNull { (accessor, obj) -> - // Ensure values get read out of input no matter what - val value = accessor.serializer.readProperty(obj, schemas, input, context) - - when(accessor) { - is PropertyAccessorConstructor -> accessor.initialPosition to value - is CalculatedPropertyAccessor -> null - else -> throw UnsupportedOperationException( - "${accessor::class.simpleName} accessor not supported " + - "for constructor-based object building") - } + fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any = + ifThrowsAppend({ typeIdentifier.prettyPrint(false) }) { + if (obj !is List<*>) throw NotSerializableException("Body of described type is unexpected $obj") + if (obj.size < propertySerializers.size) { + throw NotSerializableException("${obj.size} objects to deserialize, but " + + "${propertySerializers.size} properties in described type ${typeIdentifier.prettyPrint(false)}") } - .sortedWith(compareBy { it.first }) - .map { it.second }) - } - private fun readObjectBuildViaSetters( - obj: List<*>, - schemas: SerializationSchemas, - input: DeserializationInput, - context: SerializationContext): Any = ifThrowsAppend({ clazz.typeName }) { - logger.trace { "Calling setter based construction for ${clazz.typeName}" } + val builder = objectBuilderProvider() + builder.initialize() + obj.asSequence().zip(propertySerializers.values.asSequence()) + // Read _all_ properties from the stream + .map { (item, property) -> property to property.readProperty(item, schemas, input, context) } + // Throw away any calculated properties + .filter { (property, _) -> !property.isCalculated } + // Write the rest into the builder + .forEachIndexed { slot, (_, propertyValue) -> builder.populate(slot, propertyValue) } + return builder.build() + } +} - val instance: Any = javaConstructor?.newInstanceUnwrapped() ?: throw AMQPNotSerializableException( - type, - "Failed to instantiate instance of object $clazz") +class AbstractObjectSerializer( + override val type: Type, + override val typeDescriptor: Symbol, + override val propertySerializers: Map, + override val fields: List, + private val writer: ComposableObjectWriter): ObjectSerializer { + override fun writeClassInfo(output: SerializationOutput) = + writer.writeClassInfo(output) - // read the properties out of the serialised form, since we're invoking the setters the order we - // do it in doesn't matter - val propertiesFromBlob = obj - .zip(propertySerializers.serializationOrder) - .map { it.second.serializer.readProperty(it.first, schemas, input, context) } + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int) = + writer.writeObject(obj, data, type, output, context, debugIndent) - // one by one take a property and invoke the setter on the class - propertySerializers.serializationOrder.zip(propertiesFromBlob).forEach { - it.first.set(instance, it.second) + override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any = + throw UnsupportedOperationException("Cannot deserialize abstract type ${type.typeName}") +} + +class EvolutionObjectSerializer( + override val type: Type, + override val typeDescriptor: Symbol, + override val propertySerializers: Map, + private val reader: ComposableObjectReader): ObjectSerializer { + + companion object { + fun make(localTypeInformation: LocalTypeInformation.Composable, remoteTypeInformation: RemoteTypeInformation.Composable, constructor: LocalConstructorInformation, + properties: Map, classLoader: ClassLoader): EvolutionObjectSerializer { + val propertySerializers = makePropertySerializers(properties, remoteTypeInformation.properties, classLoader) + val reader = ComposableObjectReader( + localTypeInformation.typeIdentifier, + propertySerializers, + EvolutionObjectBuilder.makeProvider(localTypeInformation.typeIdentifier, constructor, properties, remoteTypeInformation.properties.keys.sorted())) + + return EvolutionObjectSerializer( + localTypeInformation.observedType, + Symbol.valueOf(remoteTypeInformation.typeDescriptor), + propertySerializers, + reader) } - return instance + private fun makePropertySerializers(localProperties: Map, + remoteProperties: Map, + classLoader: ClassLoader): Map = + remoteProperties.mapValues { (name, property) -> + val localProperty = localProperties[name] + val isCalculated = localProperty?.isCalculated ?: false + val type = localProperty?.type?.observedType ?: property.type.typeIdentifier.getLocalType(classLoader) + ComposableTypePropertySerializer.makeForEvolution(name, isCalculated, property.type.typeIdentifier, type) + } } - private fun generateFields(): List { - return propertySerializers.serializationOrder.map { - Field(it.serializer.name, it.serializer.type, it.serializer.requires, it.serializer.default, null, it.serializer.mandatory, false) - } - } + override val fields: List get() = emptyList() - private fun generateProvides(): List = interfaces.map { nameForType(it) } + override fun writeClassInfo(output: SerializationOutput) = + throw UnsupportedOperationException("Evolved types cannot be written") - fun construct(properties: List): Any { - logger.trace { "Calling constructor: '$javaConstructor' with properties '$properties'" } + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput, context: SerializationContext, debugIndent: Int) = + throw UnsupportedOperationException("Evolved types cannot be written") - if (properties.size != javaConstructor?.parameterCount) { - throw AMQPNotSerializableException(type, "Serialization constructor for class $type expects " - + "${javaConstructor?.parameterCount} parameters but we have ${properties.size} " - + "serialized properties.") - } + override fun readObject(obj: Any, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any = + reader.readObject(obj, schemas, input, context) - return javaConstructor?.newInstanceUnwrapped(*properties.toTypedArray()) - ?: throw AMQPNotSerializableException( - type, - "Attempt to deserialize an interface: $clazz. Serialized form is invalid.") - } - - private fun Constructor.newInstanceUnwrapped(vararg args: Any?): T { - try { - return newInstance(*args) - } catch (e: InvocationTargetException) { - throw e.cause!! - } - } } \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/PropertySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/PropertySerializer.kt deleted file mode 100644 index 3b3ee33478..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/PropertySerializer.kt +++ /dev/null @@ -1,146 +0,0 @@ -package net.corda.serialization.internal.amqp - -import net.corda.core.KeepForDJVM -import net.corda.core.serialization.SerializationContext -import org.apache.qpid.proton.amqp.Binary -import org.apache.qpid.proton.codec.Data -import java.lang.reflect.Type - -/** - * Base class for serialization of a property of an object. - */ -sealed class PropertySerializer(val name: String, val propertyReader: PropertyReader, val resolvedType: Type) { - abstract fun writeClassInfo(output: SerializationOutput) - abstract fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, context: SerializationContext, debugIndent: Int = 0) - abstract fun readProperty(obj: Any?, schemas: SerializationSchemas, input: DeserializationInput, context: SerializationContext): Any? - - val type: String = generateType() - val requires: List = generateRequires() - val default: String? = generateDefault() - val mandatory: Boolean = generateMandatory() - - private val isInterface: Boolean get() = resolvedType.asClass().isInterface - private val isJVMPrimitive: Boolean get() = resolvedType.asClass().isPrimitive - - private fun generateType(): String { - return if (isInterface || resolvedType == Any::class.java) "*" else SerializerFactory.nameForType(resolvedType) - } - - private fun generateRequires(): List { - return if (isInterface) listOf(SerializerFactory.nameForType(resolvedType)) else emptyList() - } - - private fun generateDefault(): String? = - if (isJVMPrimitive) { - when (resolvedType) { - java.lang.Boolean.TYPE -> "false" - java.lang.Character.TYPE -> "�" - else -> "0" - } - } else { - null - } - - private fun generateMandatory(): Boolean { - return isJVMPrimitive || !(propertyReader.isNullable()) - } - - companion object { - fun make(name: String, readMethod: PropertyReader, resolvedType: Type, factory: SerializerFactory): PropertySerializer { - return if (SerializerFactory.isPrimitive(resolvedType)) { - when (resolvedType) { - Char::class.java, Character::class.java -> AMQPCharPropertySerializer(name, readMethod) - else -> AMQPPrimitivePropertySerializer(name, readMethod, resolvedType) - } - } else { - DescribedTypePropertySerializer(name, readMethod, resolvedType) { factory.get(null, resolvedType) } - } - } - } - - /** - * A property serializer for a complex type (another object). - */ - @KeepForDJVM - class DescribedTypePropertySerializer( - name: String, - readMethod: PropertyReader, - resolvedType: Type, - private val lazyTypeSerializer: () -> AMQPSerializer<*>) : PropertySerializer(name, readMethod, resolvedType) { - // This is lazy so we don't get an infinite loop when a method returns an instance of the class. - private val typeSerializer: AMQPSerializer<*> by lazy { lazyTypeSerializer() } - - override fun writeClassInfo(output: SerializationOutput) = ifThrowsAppend({ nameForDebug }) { - if (resolvedType != Any::class.java) { - typeSerializer.writeClassInfo(output) - } - } - - override fun readProperty( - obj: Any?, - schemas: SerializationSchemas, - input: DeserializationInput, - context: SerializationContext): Any? = ifThrowsAppend({ nameForDebug }) { - input.readObjectOrNull(obj, schemas, resolvedType, context) - } - - override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, - context: SerializationContext, debugIndent: Int) = ifThrowsAppend({ nameForDebug } - ) { - output.writeObjectOrNull(propertyReader.read(obj), data, resolvedType, context, debugIndent) - } - - private val nameForDebug = "$name(${resolvedType.typeName})" - } - - /** - * A property serializer for most AMQP primitive type (Int, String, etc). - */ - class AMQPPrimitivePropertySerializer( - name: String, - readMethod: PropertyReader, - resolvedType: Type) : PropertySerializer(name, readMethod, resolvedType) { - override fun writeClassInfo(output: SerializationOutput) {} - - override fun readProperty(obj: Any?, schemas: SerializationSchemas, - input: DeserializationInput, context: SerializationContext - ): Any? { - return if (obj is Binary) obj.array else obj - } - - override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, - context: SerializationContext, debugIndent: Int - ) { - val value = propertyReader.read(obj) - if (value is ByteArray) { - data.putObject(Binary(value)) - } else { - data.putObject(value) - } - } - } - - /** - * A property serializer for the AMQP char type, needed as a specialisation as the underlying - * value of the character is stored in numeric UTF-16 form and on deserialization requires explicit - * casting back to a char otherwise it's treated as an Integer and a TypeMismatch occurs - */ - class AMQPCharPropertySerializer(name: String, readMethod: PropertyReader) : - PropertySerializer(name, readMethod, Character::class.java) { - override fun writeClassInfo(output: SerializationOutput) {} - - override fun readProperty(obj: Any?, schemas: SerializationSchemas, - input: DeserializationInput, context: SerializationContext - ): Any? { - return if (obj == null) null else (obj as Short).toChar() - } - - override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput, - context: SerializationContext, debugIndent: Int - ) { - val input = propertyReader.read(obj) - if (input != null) data.putShort((input as Char).toShort()) else data.putNull() - } - } -} - diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/PropertySerializers.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/PropertySerializers.kt deleted file mode 100644 index 517af0406a..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/PropertySerializers.kt +++ /dev/null @@ -1,243 +0,0 @@ -package net.corda.serialization.internal.amqp - -import net.corda.core.KeepForDJVM -import net.corda.core.serialization.SerializableCalculatedProperty -import net.corda.core.utilities.loggerFor -import java.io.NotSerializableException -import java.lang.reflect.Field -import java.lang.reflect.Method -import java.lang.reflect.Type -import kotlin.reflect.full.memberProperties -import kotlin.reflect.jvm.javaGetter -import kotlin.reflect.jvm.kotlinProperty - -abstract class PropertyReader { - abstract fun read(obj: Any?): Any? - abstract fun isNullable(): Boolean -} - -/** - * Accessor for those properties of a class that have defined getter functions. - */ -@KeepForDJVM -class PublicPropertyReader(private val readMethod: Method) : PropertyReader() { - init { - readMethod.isAccessible = true - } - - private fun Method.returnsNullable(): Boolean { - try { - val returnTypeString = this.declaringClass.kotlin.memberProperties.firstOrNull { - it.javaGetter == this - }?.returnType?.toString() ?: "?" - - return returnTypeString.endsWith('?') || returnTypeString.endsWith('!') - } catch (e: kotlin.reflect.jvm.internal.KotlinReflectionInternalError) { - // This might happen for some types, e.g. kotlin.Throwable? - the root cause of the issue - // is: https://youtrack.jetbrains.com/issue/KT-13077 - // TODO: Revisit this when Kotlin issue is fixed. - - // So this used to report as an error, but given we serialise exceptions all the time it - // provides for very scary log files so move this to trace level - loggerFor().let { logger -> - logger.trace("Using kotlin introspection on internal type ${this.declaringClass}") - logger.trace("Unexpected internal Kotlin error", e) - } - return true - } - } - - override fun read(obj: Any?): Any? { - return readMethod.invoke(obj) - } - - override fun isNullable(): Boolean = readMethod.returnsNullable() - - val genericReturnType get() = readMethod.genericReturnType -} - -/** - * Accessor for those properties of a class that do not have defined getter functions. In which case - * we used reflection to remove the unreadable status from that property whilst it's accessed. - */ -@KeepForDJVM -class PrivatePropertyReader(val field: Field, parentType: Type) : PropertyReader() { - init { - loggerFor().warn("Create property Serializer for private property '${field.name}' not " - + "exposed by a getter on class '$parentType'\n" - + "\tNOTE: This behaviour will be deprecated at some point in the future and a getter required") - } - - override fun read(obj: Any?): Any? { - field.isAccessible = true - val rtn = field.get(obj) - field.isAccessible = false - return rtn - } - - override fun isNullable() = try { - field.kotlinProperty?.returnType?.isMarkedNullable ?: false - } catch (e: kotlin.reflect.jvm.internal.KotlinReflectionInternalError) { - // This might happen for some types, e.g. kotlin.Throwable? - the root cause of the issue - // is: https://youtrack.jetbrains.com/issue/KT-13077 - // TODO: Revisit this when Kotlin issue is fixed. - - // So this used to report as an error, but given we serialise exceptions all the time it - // provides for very scary log files so move this to trace level - loggerFor().let { logger -> - logger.trace("Using kotlin introspection on internal type $field") - logger.trace("Unexpected internal Kotlin error", e) - } - true - } -} - -/** - * Special instance of a [PropertyReader] for use only by [EvolutionSerializer]s to make - * it explicit that no properties are ever actually read from an object as the evolution - * serializer should only be accessing the already serialized form. - */ -class EvolutionPropertyReader : PropertyReader() { - override fun read(obj: Any?): Any? { - throw UnsupportedOperationException("It should be impossible for an evolution serializer to " - + "be reading from an object") - } - - override fun isNullable() = true -} - -/** - * Represents a generic interface to a serializable property of an object. - * - * @property initialPosition where in the constructor used for serialization the property occurs. - * @property serializer a [PropertySerializer] wrapping access to the property. This will either be a - * method invocation on the getter or, if not publicly accessible, reflection based by temporally - * making the property accessible. - */ -abstract class PropertyAccessor( - open val serializer: PropertySerializer) { - companion object : Comparator { - override fun compare(p0: PropertyAccessor?, p1: PropertyAccessor?): Int { - return p0?.serializer?.name?.compareTo(p1?.serializer?.name ?: "") ?: 0 - } - } - - open val isCalculated get() = false - - /** - * Override to control how the property is set on the object. - */ - abstract fun set(instance: Any, obj: Any?) - - override fun toString(): String { - return serializer.name - } -} - -/** - * Implementation of [PropertyAccessor] representing a property of an object that - * is serialized and deserialized via JavaBean getter and setter style methods. - */ -class PropertyAccessorGetterSetter( - getter: PropertySerializer, - private val setter: Method) : PropertyAccessor(getter) { - init { - /** - * Play nicely with Java interop, public methods aren't marked as accessible - */ - setter.isAccessible = true - } - - /** - * Invokes the setter on the underlying object passing in the serialized value. - */ - override fun set(instance: Any, obj: Any?) { - setter.invoke(instance, *listOf(obj).toTypedArray()) - } -} - -/** - * Implementation of [PropertyAccessor] representing a property of an object that - * is serialized via a JavaBean getter but deserialized using the constructor - * of the object the property belongs to. - */ -class PropertyAccessorConstructor( - val initialPosition: Int, - override val serializer: PropertySerializer) : PropertyAccessor(serializer) { - /** - * Because the property should be being set on the object through the constructor any - * calls to the explicit setter should be an error. - */ - override fun set(instance: Any, obj: Any?) { - NotSerializableException("Attempting to access a setter on an object being instantiated " + - "via its constructor.") - } - - override fun toString(): String = - "${serializer.name}($initialPosition)" -} - -/** - * Implementation of [PropertyAccessor] representing a calculated property of an object that is serialized - * so that it can be used by the class carpenter, but ignored on deserialisation as there is no setter or - * constructor parameter to receive its value. - * - * This will only be created for calculated properties that are accessible via no-argument methods annotated - * with [SerializableCalculatedProperty]. - */ -class CalculatedPropertyAccessor(override val serializer: PropertySerializer): PropertyAccessor(serializer) { - override val isCalculated: Boolean - get() = true - - override fun set(instance: Any, obj: Any?) = Unit // do nothing, as it's a calculated value -} - -/** - * Represents a collection of [PropertyAccessor]s that represent the serialized form - * of an object. - * - * @property serializationOrder a list of [PropertyAccessor]. For deterministic serialization - * should be sorted. - * @property size how many properties are being serialized. - * @property byConstructor are the properties of the class represented by this set of properties populated - * on deserialization via the object's constructor or the corresponding setter functions. Should be - * overridden and set appropriately by child types. - */ -abstract class PropertySerializers( - val serializationOrder: List) { - companion object { - fun make(serializationOrder: List) = - when (serializationOrder.find { !it.isCalculated }) { - is PropertyAccessorConstructor -> PropertySerializersConstructor(serializationOrder) - is PropertyAccessorGetterSetter -> PropertySerializersSetter(serializationOrder) - null -> PropertySerializersNoProperties() - else -> { - throw AMQPNoTypeNotSerializableException("Unknown Property Accessor type, cannot create set") - } - } - } - - val size get() = serializationOrder.size - abstract val byConstructor: Boolean - val deserializableSize = serializationOrder.count { !it.isCalculated } -} - -class PropertySerializersNoProperties : PropertySerializers(emptyList()) { - override val byConstructor get() = true -} - -class PropertySerializersConstructor( - serializationOrder: List) : PropertySerializers(serializationOrder) { - override val byConstructor get() = true -} - -class PropertySerializersSetter( - serializationOrder: List) : PropertySerializers(serializationOrder) { - override val byConstructor get() = false -} - -class PropertySerializersEvolution : PropertySerializers(emptyList()) { - override val byConstructor get() = false -} - - diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt new file mode 100644 index 0000000000..c92947651b --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt @@ -0,0 +1,141 @@ +package net.corda.serialization.internal.amqp + +import net.corda.core.utilities.contextLogger +import net.corda.serialization.internal.model.* +import org.hibernate.type.descriptor.java.ByteTypeDescriptor +import java.io.NotSerializableException + +/** + * A factory that knows how to create serializers to deserialize values sent to us by remote parties. + */ +interface RemoteSerializerFactory { + /** + * Lookup and manufacture a serializer for the given AMQP type descriptor, assuming we also have the necessary types + * contained in the provided [Schema]. + * + * @param typeDescriptor The type descriptor for the type to obtain a serializer for. + * @param schema The schemas sent along with the serialized data. + */ + @Throws(NotSerializableException::class) + fun get(typeDescriptor: TypeDescriptor, schema: SerializationSchemas): AMQPSerializer +} + +/** + * Represents the reflection of some [RemoteTypeInformation] by some [LocalTypeInformation], which we use to make + * decisions about evolution. + */ +data class RemoteAndLocalTypeInformation( + val remoteTypeInformation: RemoteTypeInformation, + val localTypeInformation: LocalTypeInformation) + +/** + * A [RemoteSerializerFactory] which uses an [AMQPRemoteTypeModel] to interpret AMQP [Schema]s into [RemoteTypeInformation], + * reflects this into [LocalTypeInformation] using a [LocalTypeModel] and a [TypeLoader], and compares the two in order to + * decide whether to return the serializer provided by the [LocalSerializerFactory] or to construct a special evolution serializer + * using the [EvolutionSerializerFactory]. + * + * Its decisions are recorded by registering the chosen serialisers against their type descriptors + * in the [DescriptorBasedSerializerRegistry]. + * + * @param evolutionSerializerFactory The [EvolutionSerializerFactory] to use to create evolution serializers, when necessary. + * @param descriptorBasedSerializerRegistry The registry to use to store serializers by [TypeDescriptor]. + * @param remoteTypeModel The [AMQPRemoteTypeModel] to use to interpret AMPQ [Schema] information into [RemoteTypeInformation]. + * @param localTypeModel The [LocalTypeModel] to use to obtain [LocalTypeInformation] for reflected [Type]s. + * @param typeLoader The [TypeLoader] to use to load local [Type]s reflecting [RemoteTypeInformation]. + * @param localSerializerFactory The [LocalSerializerFactory] to use to obtain serializers for non-evolved types. + */ +class DefaultRemoteSerializerFactory( + private val evolutionSerializerFactory: EvolutionSerializerFactory, + private val descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry, + private val remoteTypeModel: AMQPRemoteTypeModel, + private val localTypeModel: LocalTypeModel, + private val typeLoader: TypeLoader, + private val localSerializerFactory: LocalSerializerFactory) + : RemoteSerializerFactory { + + companion object { + private val logger = contextLogger() + } + + override fun get(typeDescriptor: TypeDescriptor, schema: SerializationSchemas): AMQPSerializer = + // If we have seen this descriptor before, we assume we have seen everything in this schema before. + descriptorBasedSerializerRegistry.getOrBuild(typeDescriptor) { + logger.trace("get Serializer descriptor=$typeDescriptor") + + // Interpret all of the types in the schema into RemoteTypeInformation, and reflect that into LocalTypeInformation. + val remoteTypeInformationMap = remoteTypeModel.interpret(schema) + val reflected = reflect(remoteTypeInformationMap) + + // Get, and record in the registry, serializers for all of the types contained in the schema. + // This will save us having to re-interpret the entire schema on re-entry when deserialising individual property values. + val serializers = reflected.mapValues { (descriptor, remoteLocalPair) -> + descriptorBasedSerializerRegistry.getOrBuild(descriptor) { + getUncached(remoteLocalPair.remoteTypeInformation, remoteLocalPair.localTypeInformation) + } + } + + // Return the specific serializer the caller asked for. + serializers[typeDescriptor] ?: throw NotSerializableException( + "Could not find type matching descriptor $typeDescriptor.") + } + + private fun getUncached(remoteTypeInformation: RemoteTypeInformation, localTypeInformation: LocalTypeInformation): AMQPSerializer { + val remoteDescriptor = remoteTypeInformation.typeDescriptor + + // Obtain a serializer and descriptor for the local type. + val localSerializer = localSerializerFactory.get(localTypeInformation) + val localDescriptor = localSerializer.typeDescriptor.toString() + + return when { + // If descriptors match, we can return the local serializer straight away. + localDescriptor == remoteDescriptor -> localSerializer + + // Can we deserialise without evolution, e.g. going from List to List<*>? + remoteTypeInformation.isDeserialisableWithoutEvolutionTo(localTypeInformation) -> localSerializer + + // Are the remote/local types evolvable? If so, ask the evolution serializer factory for a serializer, returning + // the local serializer if it returns null (i.e. no evolution required). + remoteTypeInformation.isEvolvableTo(localTypeInformation) -> + evolutionSerializerFactory.getEvolutionSerializer(remoteTypeInformation, localTypeInformation) + ?: localSerializer + + // Descriptors don't match, and something is probably broken, but we let the framework do what it can with the local + // serialiser (BlobInspectorTest uniquely breaks if we throw an exception here, and passes if we just warn and continue). + else -> { + logger.warn(""" +Mismatch between type descriptors, but remote type is not evolvable to local type. + +Remote type (descriptor: $remoteDescriptor) +${remoteTypeInformation.prettyPrint(false)} + +Local type (descriptor $localDescriptor): +${localTypeInformation.prettyPrint(false)} + """) + + localSerializer + } + } + } + + private fun reflect(remoteInformation: Map): + Map { + val localInformationByIdentifier = typeLoader.load(remoteInformation.values).mapValues { (_, type) -> + localTypeModel.inspect(type) + } + + return remoteInformation.mapValues { (_, remoteInformation) -> + RemoteAndLocalTypeInformation(remoteInformation, localInformationByIdentifier[remoteInformation.typeIdentifier]!!) + } + } + + private fun RemoteTypeInformation.isEvolvableTo(localTypeInformation: LocalTypeInformation): Boolean = when(this) { + is RemoteTypeInformation.Composable -> localTypeInformation is LocalTypeInformation.Composable + is RemoteTypeInformation.AnEnum -> localTypeInformation is LocalTypeInformation.AnEnum + else -> false + } + + private fun RemoteTypeInformation.isDeserialisableWithoutEvolutionTo(localTypeInformation: LocalTypeInformation) = + this is RemoteTypeInformation.Parameterised && + (localTypeInformation is LocalTypeInformation.ACollection || + localTypeInformation is LocalTypeInformation.AMap) +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationHelper.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationHelper.kt index aed85b9825..3a7014615c 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationHelper.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationHelper.kt @@ -2,243 +2,10 @@ package net.corda.serialization.internal.amqp import com.google.common.primitives.Primitives import com.google.common.reflect.TypeToken -import net.corda.core.internal.isConcreteClass import net.corda.core.serialization.* +import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.codec.Data import java.lang.reflect.* -import java.lang.reflect.Field -import java.util.* -import kotlin.reflect.KClass -import kotlin.reflect.KFunction -import kotlin.reflect.KParameter -import kotlin.reflect.full.findAnnotation -import kotlin.reflect.full.primaryConstructor -import kotlin.reflect.jvm.isAccessible -import kotlin.reflect.jvm.javaConstructor -import kotlin.reflect.jvm.javaType - -/** - * Code for finding the constructor we will use for deserialization. - * - * If any constructor is uniquely annotated with [@ConstructorForDeserialization], then that constructor is chosen. - * An error is reported if more than one constructor is annotated. - * - * Otherwise, if there is a Kotlin primary constructor, it selects that, and if not it selects either the unique - * constructor or, if there are two and one is the default no-argument constructor, the non-default constructor. - */ -fun constructorForDeserialization(type: Type): KFunction { - val clazz = type.asClass().apply { - if (!isConcreteClass) throw AMQPNotSerializableException(type, - "Cannot find deserialisation constructor for non-concrete class $this") - } - - val kotlinCtors = clazz.kotlin.constructors - - val annotatedCtors = kotlinCtors.filter { it.findAnnotation() != null } - if (annotatedCtors.size > 1) throw AMQPNotSerializableException( - type, - "More than one constructor for $clazz is annotated with @ConstructorForDeserialization.") - - val defaultCtor = kotlinCtors.firstOrNull { it.parameters.isEmpty() } - val nonDefaultCtors = kotlinCtors.filter { it != defaultCtor } - - val preferredCandidate = annotatedCtors.firstOrNull() ?: - clazz.kotlin.primaryConstructor ?: - when(nonDefaultCtors.size) { - 1 -> nonDefaultCtors.first() - 0 -> defaultCtor ?: throw AMQPNotSerializableException(type, "No constructor found for $clazz.") - else -> throw AMQPNotSerializableException(type, "No unique non-default constructor found for $clazz.") - } - - return preferredCandidate.apply { isAccessible = true } -} - -/** - * Identifies the properties to be used during serialization by attempting to find those that match the parameters - * to the deserialization constructor, if the class is concrete. If it is abstract, or an interface, then use all - * the properties. - * - * Note, you will need any Java classes to be compiled with the `-parameters` option to ensure constructor parameters - * have names accessible via reflection. - */ -fun propertiesForSerialization( - kotlinConstructor: KFunction?, - type: Type, - factory: SerializerFactory): PropertySerializers = PropertySerializers.make( - getValueProperties(kotlinConstructor, type, factory) - .addCalculatedProperties(factory, type) - .sortedWith(PropertyAccessor)) - -fun getValueProperties(kotlinConstructor: KFunction?, type: Type, factory: SerializerFactory) - : List = - if (kotlinConstructor != null) { - propertiesForSerializationFromConstructor(kotlinConstructor, type, factory) - } else { - propertiesForSerializationFromAbstract(type.asClass(), type, factory) - } - -private fun List.addCalculatedProperties(factory: SerializerFactory, type: Type) - : List { - val nonCalculated = map { it.serializer.name }.toSet() - return this + type.asClass().calculatedPropertyDescriptors().mapNotNull { (name, descriptor) -> - if (name in nonCalculated) null else { - val calculatedPropertyMethod = descriptor.getter - ?: throw IllegalStateException("Property $name is not a calculated property") - CalculatedPropertyAccessor(PropertySerializer.make( - name, - PublicPropertyReader(calculatedPropertyMethod), - calculatedPropertyMethod.genericReturnType, - factory)) - } - } -} - -/** - * From a constructor, determine which properties of a class are to be serialized. - * - * @param kotlinConstructor The constructor to be used to instantiate instances of the class - * @param type The class's [Type] - * @param factory The factory generating the serializer wrapping this function. - */ -internal fun propertiesForSerializationFromConstructor( - kotlinConstructor: KFunction, - type: Type, - factory: SerializerFactory): List { - val clazz = (kotlinConstructor.returnType.classifier as KClass<*>).javaObjectType - - val classProperties = clazz.propertyDescriptors() - - // Annoyingly there isn't a better way to ascertain that the constructor for the class - // has a synthetic parameter inserted to capture the reference to the outer class. You'd - // think you could inspect the parameter and check the isSynthetic flag but that is always - // false so given the naming convention is specified by the standard we can just check for - // this - kotlinConstructor.javaConstructor?.apply { - if (parameterCount > 0 && parameters[0].name == "this$0") throw SyntheticParameterException(type) - } - - if (classProperties.isNotEmpty() && kotlinConstructor.parameters.isEmpty()) { - return propertiesForSerializationFromSetters(classProperties, type, factory) - } - - return kotlinConstructor.parameters.withIndex().map { param -> - toPropertyAccessorConstructor(param.index, param.value, classProperties, type, clazz, factory) - } -} - -private fun toPropertyAccessorConstructor(index: Int, param: KParameter, classProperties: Map, type: Type, clazz: Class, factory: SerializerFactory): PropertyAccessorConstructor { - // name cannot be null, if it is then this is a synthetic field and we will have bailed - // out prior to this - val name = param.name!! - - // We will already have disambiguated getA for property A or a but we still need to cope - // with the case we don't know the case of A when the parameter doesn't match a property - // but has a getter - val matchingProperty = classProperties[name] ?: classProperties[name.capitalize()] - ?: throw AMQPNotSerializableException(type, - "Constructor parameter - \"$name\" - doesn't refer to a property of \"$clazz\"") - - // If the property has a getter we'll use that to retrieve it's value from the instance, if it doesn't - // *for *now* we switch to a reflection based method - val propertyReader = matchingProperty.getter?.let { getter -> - getPublicPropertyReader(getter, type, param, name, clazz) - } ?: matchingProperty.field?.let { field -> - getPrivatePropertyReader(field, type) - } ?: throw AMQPNotSerializableException(type, - "No property matching constructor parameter named - \"$name\" - " + - "of \"${param}\". If using Java, check that you have the -parameters option specified " + - "in the Java compiler. Alternately, provide a proxy serializer " + - "(SerializationCustomSerializer) if recompiling isn't an option") - - return PropertyAccessorConstructor( - index, - PropertySerializer.make(name, propertyReader.first, propertyReader.second, factory)) -} - -/** - * If we determine a class has a constructor that takes no parameters then check for pairs of getters / setters - * and use those - */ -fun propertiesForSerializationFromSetters( - properties: Map, - type: Type, - factory: SerializerFactory): List = - properties.asSequence().map { entry -> - val (name, property) = entry - - val getter = property.getter - val setter = property.setter - - if (getter == null || setter == null) return@map null - - PropertyAccessorGetterSetter( - PropertySerializer.make( - name, - PublicPropertyReader(getter), - resolveTypeVariables(getter.genericReturnType, type), - factory), - setter) - }.filterNotNull().toList() - -private fun getPrivatePropertyReader(field: Field, type: Type) = - PrivatePropertyReader(field, type) to resolveTypeVariables(field.genericType, type) - -private fun getPublicPropertyReader(getter: Method, type: Type, param: KParameter, name: String, clazz: Class): Pair { - val returnType = resolveTypeVariables(getter.genericReturnType, type) - val paramToken = TypeToken.of(param.type.javaType) - val rawParamType = TypeToken.of(paramToken.rawType) - - if (!(paramToken.isSupertypeOf(returnType) - || paramToken.isSupertypeOf(getter.genericReturnType) - // cope with the case where the constructor parameter is a generic type (T etc) but we - // can discover it's raw type. When bounded this wil be the bounding type, unbounded - // generics this will be object - || rawParamType.isSupertypeOf(returnType) - || rawParamType.isSupertypeOf(getter.genericReturnType))) { - throw AMQPNotSerializableException( - type, - "Property - \"$name\" - has type \"$returnType\" on \"$clazz\" " + - "but differs from constructor parameter type \"${param.type.javaType}\"") - } - - return PublicPropertyReader(getter) to returnType -} - -private fun propertiesForSerializationFromAbstract( - clazz: Class<*>, - type: Type, - factory: SerializerFactory): List = - clazz.propertyDescriptors().asSequence().withIndex().mapNotNull { (index, entry) -> - val (name, property) = entry - if (property.getter == null || property.field == null) return@mapNotNull null - - val getter = property.getter - val returnType = resolveTypeVariables(getter.genericReturnType, type) - - PropertyAccessorConstructor( - index, - PropertySerializer.make(name, PublicPropertyReader(getter), returnType, factory)) - }.toList() - -internal fun interfacesForSerialization(type: Type, serializerFactory: SerializerFactory): List = - exploreType(type, serializerFactory).toList() - -private fun exploreType(type: Type, serializerFactory: SerializerFactory, interfaces: MutableSet = LinkedHashSet()): MutableSet { - val clazz = type.asClass() - - if (clazz.isInterface) { - // Ignore classes we've already seen, and stop exploring once we reach a branch that has no `CordaSerializable` - // annotation or whitelisting. - if (clazz in interfaces || serializerFactory.whitelist.isNotWhitelisted(clazz)) return interfaces - else interfaces += type - } - - (clazz.genericInterfaces.asSequence() + clazz.genericSuperclass) - .filterNotNull() - .forEach { exploreType(resolveTypeVariables(it, type), serializerFactory, interfaces) } - - return interfaces -} /** * Extension helper for writing described objects. @@ -283,7 +50,7 @@ fun resolveTypeVariables(actualType: Type, contextType: Type?): Type { return if (resolvedType is TypeVariable<*>) { val bounds = resolvedType.bounds return if (bounds.isEmpty()) { - SerializerFactory.AnyType + TypeIdentifier.UnknownType.getLocalType() } else if (bounds.size == 1) { resolveTypeVariables(bounds[0], contextType) } else throw AMQPNotSerializableException( @@ -309,8 +76,9 @@ internal fun Type.asClass(): Class<*> { internal fun Type.asArray(): Type? { return when(this) { - is Class<*> -> this.arrayClass() - is ParameterizedType -> DeserializedGenericArrayType(this) + is Class<*>, + is ParameterizedType -> TypeIdentifier.ArrayOf(TypeIdentifier.forGenericType(this)) + .getLocalType(this::class.java.classLoader ?: TypeIdentifier::class.java.classLoader) else -> null } } @@ -324,9 +92,10 @@ internal fun Type.componentType(): Type { return (this as? Class<*>)?.componentType ?: (this as GenericArrayType).genericComponentType } -internal fun Class<*>.asParameterizedType(): ParameterizedType { - return DeserializedParameterizedType(this, this.typeParameters) -} +internal fun Class<*>.asParameterizedType(): ParameterizedType = + TypeIdentifier.Erased(this.name, this.typeParameters.size) + .toParameterized(this.typeParameters.map { TypeIdentifier.forGenericType(it) }) + .getLocalType(classLoader ?: TypeIdentifier::class.java.classLoader) as ParameterizedType internal fun Type.asParameterizedType(): ParameterizedType { return when (this) { @@ -374,19 +143,4 @@ fun hasCordaSerializable(type: Class<*>): Boolean { return type.isAnnotationPresent(CordaSerializable::class.java) || type.interfaces.any(::hasCordaSerializable) || (type.superclass != null && hasCordaSerializable(type.superclass)) -} - -fun isJavaPrimitive(type: Class<*>) = type in JavaPrimitiveTypes.primativeTypes - -private object JavaPrimitiveTypes { - val primativeTypes = hashSetOf>( - Boolean::class.java, - Char::class.java, - Byte::class.java, - Short::class.java, - Int::class.java, - Long::class.java, - Float::class.java, - Double::class.java, - Void::class.java) -} +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt index d24e5ea77b..1ba283203f 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializationOutput.kt @@ -7,10 +7,12 @@ import net.corda.core.utilities.contextLogger import net.corda.serialization.internal.CordaSerializationEncoding import net.corda.serialization.internal.SectionId import net.corda.serialization.internal.byteArrayOutput +import net.corda.serialization.internal.model.TypeIdentifier import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException import java.io.OutputStream import java.lang.reflect.Type +import java.lang.reflect.WildcardType import java.util.* import kotlin.collections.LinkedHashSet @@ -28,7 +30,7 @@ data class BytesAndSchemas( */ @KeepForDJVM open class SerializationOutput constructor( - internal val serializerFactory: SerializerFactory + internal val serializerFactory: LocalSerializerFactory ) { companion object { private val logger = contextLogger() @@ -118,7 +120,7 @@ open class SerializationOutput constructor( if (obj == null) { data.putNull() } else { - writeObject(obj, data, if (type == SerializerFactory.AnyType) obj.javaClass else type, context, debugIndent) + writeObject(obj, data, if (type == TypeIdentifier.UnknownType.getLocalType()) obj.javaClass else type, context, debugIndent) } } @@ -148,8 +150,15 @@ open class SerializationOutput constructor( } internal open fun requireSerializer(type: Type) { - if (type != SerializerFactory.AnyType && type != Object::class.java) { - val serializer = serializerFactory.get(null, type) + if (type != Object::class.java && type.typeName != "?") { + val resolvedType = when(type) { + is WildcardType -> + if (type.upperBounds.size == 1) type.upperBounds[0] + else throw NotSerializableException("Cannot obtain upper bound for type $type") + else -> type + } + + val serializer = serializerFactory.get(resolvedType) if (serializer !in serializerHistory) { serializerHistory.add(serializer) serializer.writeClassInfo(this) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactory.kt index 8e2e1e1330..42d0f1dda9 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactory.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactory.kt @@ -1,29 +1,11 @@ package net.corda.serialization.internal.amqp -import com.google.common.primitives.Primitives import net.corda.core.KeepForDJVM -import net.corda.core.StubOutForDJVM -import net.corda.core.internal.kotlinObjectInstance -import net.corda.core.internal.uncheckedCast -import net.corda.core.serialization.ClassWhitelist -import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.debug -import net.corda.core.utilities.loggerFor -import net.corda.core.utilities.trace -import net.corda.serialization.internal.carpenter.* -import net.corda.serialization.internal.model.DefaultCacheProvider -import org.apache.qpid.proton.amqp.* import java.io.NotSerializableException -import java.lang.reflect.* -import java.util.* import javax.annotation.concurrent.ThreadSafe @KeepForDJVM data class SerializationSchemas(val schema: Schema, val transforms: TransformsSchema) -@KeepForDJVM -data class FactorySchemaAndDescriptor(val schemas: SerializationSchemas, val typeDescriptor: Any) -@KeepForDJVM -data class CustomSerializersCacheKey(val clazz: Class<*>, val declaredType: Type) /** * Factory of serializers designed to be shared across threads and invocations. @@ -34,426 +16,15 @@ data class CustomSerializersCacheKey(val clazz: Class<*>, val declaredType: Type * @property onlyCustomSerializers used for testing, when set will cause the factory to throw a * [NotSerializableException] if it cannot find a registered custom serializer for a given type */ -// TODO: support for intern-ing of deserialized objects for some core types (e.g. PublicKey) for memory efficiency -// TODO: maybe support for caching of serialized form of some core types for performance -// TODO: profile for performance in general -// TODO: use guava caches etc so not unbounded -// TODO: allow definition of well known types that are left out of the schema. -// TODO: migrate some core types to unsigned integer descriptor -// TODO: document and alert to the fact that classes cannot default superclass/interface properties otherwise they are "erased" due to matching with constructor. -// TODO: type name prefixes for interfaces and abstract classes? Or use label? -// TODO: generic types should define restricted type alias with source of the wildcarded version, I think, if we're to generate classes from schema -// TODO: need to rethink matching of constructor to properties in relation to implementing interfaces and needing those properties etc. -// TODO: need to support super classes as well as interfaces with our current code base... what's involved? If we continue to ban, what is the impact? @KeepForDJVM @ThreadSafe -interface SerializerFactory { - val whitelist: ClassWhitelist - val classCarpenter: ClassCarpenter - val fingerPrinterConstructor: (SerializerFactory) -> FingerPrinter - // Caches - val serializersByType: MutableMap> - val serializersByDescriptor: MutableMap> - val transformsCache: MutableMap>> - val fingerPrinter: FingerPrinter - val classloader: ClassLoader - /** - * Look up, and manufacture if necessary, a serializer for the given type. - * - * @param actualClass Will be null if there isn't an actual object instance available (e.g. for - * restricted type processing). - */ - @Throws(NotSerializableException::class) - fun get(actualClass: Class<*>?, declaredType: Type): AMQPSerializer +interface SerializerFactory : LocalSerializerFactory, RemoteSerializerFactory, CustomSerializerRegistry - /** - * Lookup and manufacture a serializer for the given AMQP type descriptor, assuming we also have the necessary types - * contained in the [Schema]. - */ - @Throws(NotSerializableException::class) - fun get(typeDescriptor: Any, schema: SerializationSchemas): AMQPSerializer - - /** - * Register a custom serializer for any type that cannot be serialized or deserialized by the default serializer - * that expects to find getters and a constructor with a parameter for each property. - */ - fun register(customSerializer: CustomSerializer) - - fun findCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer? - fun registerExternal(customSerializer: CorDappCustomSerializer) - fun registerByDescriptor(name: Symbol, serializerCreator: () -> AMQPSerializer): AMQPSerializer - - object AnyType : WildcardType { - override fun getUpperBounds(): Array = arrayOf(Object::class.java) - - override fun getLowerBounds(): Array = emptyArray() - - override fun toString(): String = "?" - } - - companion object { - fun isPrimitive(type: Type): Boolean = primitiveTypeName(type) != null - - fun primitiveTypeName(type: Type): String? { - val clazz = type as? Class<*> ?: return null - return primitiveTypeNames[Primitives.unwrap(clazz)] - } - - fun primitiveType(type: String): Class<*>? { - return namesOfPrimitiveTypes[type] - } - - private val primitiveTypeNames: Map, String> = mapOf( - Character::class.java to "char", - Char::class.java to "char", - Boolean::class.java to "boolean", - Byte::class.java to "byte", - UnsignedByte::class.java to "ubyte", - Short::class.java to "short", - UnsignedShort::class.java to "ushort", - Int::class.java to "int", - UnsignedInteger::class.java to "uint", - Long::class.java to "long", - UnsignedLong::class.java to "ulong", - Float::class.java to "float", - Double::class.java to "double", - Decimal32::class.java to "decimal32", - Decimal64::class.java to "decimal64", - Decimal128::class.java to "decimal128", - Date::class.java to "timestamp", - UUID::class.java to "uuid", - ByteArray::class.java to "binary", - String::class.java to "string", - Symbol::class.java to "symbol") - - private val namesOfPrimitiveTypes: Map> = primitiveTypeNames.map { it.value to it.key }.toMap() - - fun nameForType(type: Type): String = when (type) { - is Class<*> -> { - primitiveTypeName(type) ?: if (type.isArray) { - "${nameForType(type.componentType)}${if (type.componentType.isPrimitive) "[p]" else "[]"}" - } else type.name - } - is ParameterizedType -> { - "${nameForType(type.rawType)}<${type.actualTypeArguments.joinToString { nameForType(it) }}>" - } - is GenericArrayType -> "${nameForType(type.genericComponentType)}[]" - is WildcardType -> "?" - is TypeVariable<*> -> "?" - else -> throw AMQPNotSerializableException(type, "Unable to render type $type to a string.") - } - } -} - -open class DefaultSerializerFactory( - override val whitelist: ClassWhitelist, - override val classCarpenter: ClassCarpenter, - private val evolutionSerializerProvider: EvolutionSerializerProvider, - override val fingerPrinterConstructor: (SerializerFactory) -> FingerPrinter, - private val onlyCustomSerializers: Boolean = false -) : SerializerFactory { - - // Caches - override val serializersByType: MutableMap> = DefaultCacheProvider.createCache() - override val serializersByDescriptor: MutableMap> = DefaultCacheProvider.createCache() - private var customSerializers: List = emptyList() - private val customSerializersCache: MutableMap?> = DefaultCacheProvider.createCache() - override val transformsCache: MutableMap>> = DefaultCacheProvider.createCache() - - override val fingerPrinter by lazy { fingerPrinterConstructor(this) } - - override val classloader: ClassLoader get() = classCarpenter.classloader - - // Used to short circuit any computation for a given input, for performance. - private data class MemoType(val actualClass: Class<*>?, val declaredType: Type) : Type - - /** - * Look up, and manufacture if necessary, a serializer for the given type. - * - * @param actualClass Will be null if there isn't an actual object instance available (e.g. for - * restricted type processing). - */ - @Throws(NotSerializableException::class) - override fun get(actualClass: Class<*>?, declaredType: Type): AMQPSerializer { - // can be useful to enable but will be *extremely* chatty if you do - logger.trace { "Get Serializer for $actualClass ${declaredType.typeName}" } - - val ourType = MemoType(actualClass, declaredType) - // ConcurrentHashMap.get() is lock free, but computeIfAbsent is not, even if the key is in the map already. - return serializersByType[ourType] ?: run { - - val declaredClass = declaredType.asClass() - val actualType: Type = if (actualClass == null) declaredType - else inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType - - val serializer = when { - // Declared class may not be set to Collection, but actual class could be a collection. - // In this case use of CollectionSerializer is perfectly appropriate. - (Collection::class.java.isAssignableFrom(declaredClass) || - (actualClass != null && Collection::class.java.isAssignableFrom(actualClass))) && - !EnumSet::class.java.isAssignableFrom(actualClass ?: declaredClass) -> { - val declaredTypeAmended = CollectionSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass) - serializersByType.computeIfAbsent(declaredTypeAmended) { - CollectionSerializer(declaredTypeAmended, this) - } - } - // Declared class may not be set to Map, but actual class could be a map. - // In this case use of MapSerializer is perfectly appropriate. - (Map::class.java.isAssignableFrom(declaredClass) || - (actualClass != null && Map::class.java.isAssignableFrom(actualClass))) -> { - val declaredTypeAmended = MapSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass) - serializersByType.computeIfAbsent(declaredTypeAmended) { - makeMapSerializer(declaredTypeAmended) - } - } - Enum::class.java.isAssignableFrom(actualClass ?: declaredClass) -> { - logger.trace { - "class=[${actualClass?.simpleName} | $declaredClass] is an enumeration " + - "declaredType=${declaredType.typeName} " + - "isEnum=${declaredType::class.java.isEnum}" - } - - serializersByType.computeIfAbsent(actualClass ?: declaredClass) { - whitelist.requireWhitelisted(actualType) - EnumSerializer(actualType, actualClass ?: declaredClass, this) - } - } - else -> { - makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) - } - } - - serializersByDescriptor.putIfAbsent(serializer.typeDescriptor, serializer) - // Always store the short-circuit too, for performance. - serializersByType.putIfAbsent(ourType, serializer) - return serializer - } - } - - /** - * Lookup and manufacture a serializer for the given AMQP type descriptor, assuming we also have the necessary types - * contained in the [Schema]. - */ - @Throws(NotSerializableException::class) - override fun get(typeDescriptor: Any, schema: SerializationSchemas): AMQPSerializer { - return serializersByDescriptor[typeDescriptor] ?: { - logger.trace("get Serializer descriptor=${typeDescriptor}") - processSchema(FactorySchemaAndDescriptor(schema, typeDescriptor)) - serializersByDescriptor[typeDescriptor] ?: throw NotSerializableException( - "Could not find type matching descriptor $typeDescriptor.") - }() - } - - /** - * Register a custom serializer for any type that cannot be serialized or deserialized by the default serializer - * that expects to find getters and a constructor with a parameter for each property. - */ - override fun register(customSerializer: CustomSerializer) { - logger.trace("action=\"Registering custom serializer\", class=\"${customSerializer.type}\"") - if (!serializersByDescriptor.containsKey(customSerializer.typeDescriptor)) { - customSerializers += customSerializer - serializersByDescriptor[customSerializer.typeDescriptor] = customSerializer - for (additional in customSerializer.additionalSerializers) { - register(additional) - } - } - } - - override fun registerExternal(customSerializer: CorDappCustomSerializer) { - logger.trace("action=\"Registering external serializer\", class=\"${customSerializer.type}\"") - if (!serializersByDescriptor.containsKey(customSerializer.typeDescriptor)) { - customSerializers += customSerializer - serializersByDescriptor[customSerializer.typeDescriptor] = customSerializer - } - } - - /** - * Iterate over an AMQP schema, for each type ascertain whether it's on ClassPath of [classloader] and, - * if not, use the [ClassCarpenter] to generate a class to use in its place. - */ - private fun processSchema(schemaAndDescriptor: FactorySchemaAndDescriptor, sentinel: Boolean = false) { - val requiringCarpentry = schemaAndDescriptor.schemas.schema.types.mapNotNull { typeNotation -> - try { - getOrRegisterSerializer(schemaAndDescriptor, typeNotation) - return@mapNotNull null - } catch (e: ClassNotFoundException) { - if (sentinel) { - logger.error("typeNotation=${typeNotation.name} error=\"after Carpentry attempt failed to load\"") - throw e - } - logger.trace { "typeNotation=\"${typeNotation.name}\" action=\"carpentry required\"" } - return@mapNotNull typeNotation - } - }.toList() - - if (requiringCarpentry.isEmpty()) return - - runCarpentry(schemaAndDescriptor, CarpenterMetaSchema.buildWith(classloader, requiringCarpentry)) - } - - private fun getOrRegisterSerializer(schemaAndDescriptor: FactorySchemaAndDescriptor, typeNotation: TypeNotation) { - logger.trace { "descriptor=${schemaAndDescriptor.typeDescriptor}, typeNotation=${typeNotation.name}" } - val serialiser = processSchemaEntry(typeNotation) - - // if we just successfully built a serializer for the type but the type fingerprint - // doesn't match that of the serialised object then we may be dealing with different - // instance of the class, and such we need to build an EvolutionSerializer - if (serialiser.typeDescriptor == typeNotation.descriptor.name) return - - logger.trace { "typeNotation=${typeNotation.name} action=\"requires Evolution\"" } - evolutionSerializerProvider.getEvolutionSerializer(this, typeNotation, serialiser, schemaAndDescriptor.schemas) - } - - private fun processSchemaEntry(typeNotation: TypeNotation) = when (typeNotation) { - // java.lang.Class (whether a class or interface) - is CompositeType -> { - logger.trace("typeNotation=${typeNotation.name} amqpType=CompositeType") - processCompositeType(typeNotation) - } - // Collection / Map, possibly with generics - is RestrictedType -> { - logger.trace("typeNotation=${typeNotation.name} amqpType=RestrictedType") - processRestrictedType(typeNotation) - } - } - - // TODO: class loader logic, and compare the schema. - private fun processRestrictedType(typeNotation: RestrictedType) = - get(null, typeForName(typeNotation.name, classloader)) - - private fun processCompositeType(typeNotation: CompositeType): AMQPSerializer { - // TODO: class loader logic, and compare the schema. - val type = typeForName(typeNotation.name, classloader) - return get(type.asClass(), type) - } - - private fun typeForName(name: String, classloader: ClassLoader): Type = when { - name.endsWith("[]") -> { - val elementType = typeForName(name.substring(0, name.lastIndex - 1), classloader) - if (elementType is ParameterizedType || elementType is GenericArrayType) { - DeserializedGenericArrayType(elementType) - } else if (elementType is Class<*>) { - java.lang.reflect.Array.newInstance(elementType, 0).javaClass - } else { - throw AMQPNoTypeNotSerializableException("Not able to deserialize array type: $name") - } - } - name.endsWith("[p]") -> // There is no need to handle the ByteArray case as that type is coercible automatically - // to the binary type and is thus handled by the main serializer and doesn't need a - // special case for a primitive array of bytes - when (name) { - "int[p]" -> IntArray::class.java - "char[p]" -> CharArray::class.java - "boolean[p]" -> BooleanArray::class.java - "float[p]" -> FloatArray::class.java - "double[p]" -> DoubleArray::class.java - "short[p]" -> ShortArray::class.java - "long[p]" -> LongArray::class.java - else -> throw AMQPNoTypeNotSerializableException("Not able to deserialize array type: $name") - } - else -> DeserializedParameterizedType.make(name, classloader) - } - - @StubOutForDJVM - private fun runCarpentry(schemaAndDescriptor: FactorySchemaAndDescriptor, metaSchema: CarpenterMetaSchema) { - val mc = MetaCarpenter(metaSchema, classCarpenter) - try { - mc.build() - } catch (e: MetaCarpenterException) { - // preserve the actual message locally - loggerFor().apply { - error("${e.message} [hint: enable trace debugging for the stack trace]") - trace("", e) - } - - // prevent carpenter exceptions escaping into the world, convert things into a nice - // NotSerializableException for when this escapes over the wire - NotSerializableException(e.name) - } - processSchema(schemaAndDescriptor, true) - } - - private fun makeClassSerializer( - clazz: Class<*>, - type: Type, - declaredType: Type - ): AMQPSerializer = serializersByType.computeIfAbsent(type) { - logger.debug { "class=${clazz.simpleName}, type=$type is a composite type" } - if (clazz.isSynthetic) { - // Explicitly ban synthetic classes, we have no way of recreating them when deserializing. This also - // captures Lambda expressions and other anonymous functions - throw AMQPNotSerializableException( - type, - "Serializer does not support synthetic classes") - } else if (SerializerFactory.isPrimitive(clazz)) { - AMQPPrimitiveSerializer(clazz) - } else { - findCustomSerializer(clazz, declaredType) ?: run { - if (onlyCustomSerializers) { - throw AMQPNotSerializableException(type, "Only allowing custom serializers") - } - if (type.isArray()) { - // Don't need to check the whitelist since each element will come back through the whitelisting process. - if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) - else ArraySerializer.make(type, this) - } else { - val singleton = clazz.kotlinObjectInstance - if (singleton != null) { - whitelist.requireWhitelisted(clazz) - SingletonSerializer(clazz, singleton, this) - } else { - whitelist.requireWhitelisted(type) - ObjectSerializer(type, this) - } - } - } - } - } - - override fun findCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer? { - return customSerializersCache.computeIfAbsent(CustomSerializersCacheKey(clazz, declaredType), ::doFindCustomSerializer) - } - - private fun doFindCustomSerializer(key: CustomSerializersCacheKey): AMQPSerializer? { - val (clazz, declaredType) = key - - // e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is - // AbstractMap, only Map. Otherwise it needs to inject additional schema for a RestrictedType source of the - // super type. Could be done, but do we need it? - for (customSerializer in customSerializers) { - if (customSerializer.isSerializerFor(clazz)) { - val declaredSuperClass = declaredType.asClass().superclass - - - return if (declaredSuperClass == null - || !customSerializer.isSerializerFor(declaredSuperClass) - || !customSerializer.revealSubclassesInSchema - ) { - logger.debug("action=\"Using custom serializer\", class=${clazz.typeName}, " + - "declaredType=${declaredType.typeName}") - - @Suppress("UNCHECKED_CAST") - customSerializer as? AMQPSerializer - } else { - // Make a subclass serializer for the subclass and return that... - CustomSerializer.SubClass(clazz, uncheckedCast(customSerializer)) - } - } - } - return null - } - - private fun makeMapSerializer(declaredType: ParameterizedType): AMQPSerializer { - val rawType = declaredType.rawType as Class<*> - rawType.checkSupportedMapType() - return MapSerializer(declaredType, this) - } - - override fun registerByDescriptor(name: Symbol, serializerCreator: () -> AMQPSerializer): AMQPSerializer = - serializersByDescriptor.computeIfAbsent(name) { _ -> serializerCreator() } - - companion object { - private val logger = contextLogger() - } - -} \ No newline at end of file +class ComposedSerializerFactory( + private val localSerializerFactory: LocalSerializerFactory, + private val remoteSerializerFactory: RemoteSerializerFactory, + private val customSerializerRegistry: CachingCustomSerializerRegistry +) : SerializerFactory, + LocalSerializerFactory by localSerializerFactory, + RemoteSerializerFactory by remoteSerializerFactory, + CustomSerializerRegistry by customSerializerRegistry \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt index 0d05d4e0de..665eda5c56 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt @@ -5,49 +5,126 @@ import net.corda.core.KeepForDJVM import net.corda.core.serialization.ClassWhitelist import net.corda.serialization.internal.carpenter.ClassCarpenter import net.corda.serialization.internal.carpenter.ClassCarpenterImpl +import net.corda.serialization.internal.model.* +import java.io.NotSerializableException @KeepForDJVM object SerializerFactoryBuilder { - + @JvmStatic - @JvmOverloads - fun build( - whitelist: ClassWhitelist, - classCarpenter: ClassCarpenter, - evolutionSerializerProvider: EvolutionSerializerProvider = DefaultEvolutionSerializerProvider, - fingerPrinterProvider: (SerializerFactory) -> FingerPrinter = ::SerializerFingerPrinter, - onlyCustomSerializers: Boolean = false): SerializerFactory { + fun build(whitelist: ClassWhitelist, classCarpenter: ClassCarpenter): SerializerFactory { return makeFactory( whitelist, classCarpenter, - evolutionSerializerProvider, - fingerPrinterProvider, - onlyCustomSerializers) + DefaultDescriptorBasedSerializerRegistry(), + true, + null, + false, + false) + } + + @JvmStatic + @DeleteForDJVM + fun build( + whitelist: ClassWhitelist, + classCarpenter: ClassCarpenter, + descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry = + DefaultDescriptorBasedSerializerRegistry(), + allowEvolution: Boolean = true, + overrideFingerPrinter: FingerPrinter? = null, + onlyCustomSerializers: Boolean = false, + mustPreserveDataWhenEvolving: Boolean = false): SerializerFactory { + return makeFactory( + whitelist, + classCarpenter, + descriptorBasedSerializerRegistry, + allowEvolution, + overrideFingerPrinter, + onlyCustomSerializers, + mustPreserveDataWhenEvolving) } @JvmStatic - @JvmOverloads @DeleteForDJVM fun build( whitelist: ClassWhitelist, carpenterClassLoader: ClassLoader, lenientCarpenterEnabled: Boolean = false, - evolutionSerializerProvider: EvolutionSerializerProvider = DefaultEvolutionSerializerProvider, - fingerPrinterProvider: (SerializerFactory) -> FingerPrinter = ::SerializerFingerPrinter, - onlyCustomSerializers: Boolean = false): SerializerFactory { + descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry = + DefaultDescriptorBasedSerializerRegistry(), + allowEvolution: Boolean = true, + overrideFingerPrinter: FingerPrinter? = null, + onlyCustomSerializers: Boolean = false, + mustPreserveDataWhenEvolving: Boolean = false): SerializerFactory { return makeFactory( whitelist, ClassCarpenterImpl(whitelist, carpenterClassLoader, lenientCarpenterEnabled), - evolutionSerializerProvider, - fingerPrinterProvider, - onlyCustomSerializers) + descriptorBasedSerializerRegistry, + allowEvolution, + overrideFingerPrinter, + onlyCustomSerializers, + mustPreserveDataWhenEvolving) } private fun makeFactory(whitelist: ClassWhitelist, classCarpenter: ClassCarpenter, - evolutionSerializerProvider: EvolutionSerializerProvider, - fingerPrinterProvider: (SerializerFactory) -> FingerPrinter, - onlyCustomSerializers: Boolean) = - DefaultSerializerFactory(whitelist, classCarpenter, evolutionSerializerProvider, fingerPrinterProvider, - onlyCustomSerializers) + descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry, + allowEvolution: Boolean, + overrideFingerPrinter: FingerPrinter?, + onlyCustomSerializers: Boolean, + mustPreserveDataWhenEvolving: Boolean): SerializerFactory { + val customSerializerRegistry = CachingCustomSerializerRegistry(descriptorBasedSerializerRegistry) + + val localTypeModel = ConfigurableLocalTypeModel( + WhitelistBasedTypeModelConfiguration( + whitelist, + customSerializerRegistry)) + + val fingerPrinter = overrideFingerPrinter ?: + TypeModellingFingerPrinter(customSerializerRegistry) + + val localSerializerFactory = DefaultLocalSerializerFactory( + whitelist, + localTypeModel, + fingerPrinter, + classCarpenter.classloader, + descriptorBasedSerializerRegistry, + customSerializerRegistry, + onlyCustomSerializers) + + val typeLoader = ClassCarpentingTypeLoader( + SchemaBuildingRemoteTypeCarpenter(classCarpenter), + classCarpenter.classloader) + + val evolutionSerializerFactory = if (allowEvolution) DefaultEvolutionSerializerFactory( + localSerializerFactory, + classCarpenter.classloader, + mustPreserveDataWhenEvolving + ) else NoEvolutionSerializerFactory + + val remoteSerializerFactory = DefaultRemoteSerializerFactory( + evolutionSerializerFactory, + descriptorBasedSerializerRegistry, + AMQPRemoteTypeModel(), + localTypeModel, + typeLoader, + localSerializerFactory) + + return ComposedSerializerFactory(localSerializerFactory, remoteSerializerFactory, customSerializerRegistry) + } + +} + +object NoEvolutionSerializerFactory : EvolutionSerializerFactory { + override fun getEvolutionSerializer(remoteTypeInformation: RemoteTypeInformation, localTypeInformation: LocalTypeInformation): AMQPSerializer { + throw NotSerializableException(""" +Evolution not permitted. + +Remote: +${remoteTypeInformation.prettyPrint(false)} + +Local: +${localTypeInformation.prettyPrint(false)} + """) + } } \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SingletonSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SingletonSerializer.kt index 584501a877..8328620504 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SingletonSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SingletonSerializer.kt @@ -1,6 +1,7 @@ package net.corda.serialization.internal.amqp import net.corda.core.serialization.SerializationContext +import net.corda.serialization.internal.model.LocalTypeInformation import org.apache.qpid.proton.amqp.Symbol import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type @@ -10,13 +11,12 @@ import java.lang.reflect.Type * absolutely nothing, or null as a described type) when we have a singleton within the node that we just * want converting back to that singleton instance on the receiving JVM. */ -class SingletonSerializer(override val type: Class<*>, val singleton: Any, factory: SerializerFactory) : AMQPSerializer { - override val typeDescriptor = Symbol.valueOf( - "$DESCRIPTOR_DOMAIN:${factory.fingerPrinter.fingerprint(type)}")!! +class SingletonSerializer(override val type: Class<*>, val singleton: Any, factory: LocalSerializerFactory) : AMQPSerializer { + override val typeDescriptor = factory.createDescriptor(type) - private val interfaces = interfacesForSerialization(type, factory) + private val interfaces = (factory.getTypeInformation(type) as LocalTypeInformation.Singleton).interfaces - private fun generateProvides(): List = interfaces.map { it.typeName } + private fun generateProvides(): List = interfaces.map { it.typeIdentifier.name } internal val typeNotation: TypeNotation = RestrictedType(type.typeName, "Singleton", generateProvides(), "boolean", Descriptor(typeDescriptor), emptyList()) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt index 1409bc95c7..15afeb60b1 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TransformsSchema.kt @@ -7,6 +7,7 @@ import net.corda.core.utilities.contextLogger import net.corda.core.utilities.trace import net.corda.serialization.internal.NotSerializableDetailedException import net.corda.serialization.internal.NotSerializableWithReasonException +import net.corda.serialization.internal.model.DefaultCacheProvider import org.apache.qpid.proton.amqp.DescribedType import org.apache.qpid.proton.codec.DescribedTypeConstructor import java.io.NotSerializableException @@ -207,7 +208,8 @@ data class TransformsSchema(val types: Map>(TransformTypes::class.java) try { val clazz = sf.classloader.loadClass(name) @@ -244,7 +246,7 @@ data class TransformsSchema(val types: Map>>) { try { get(type, sf).apply { @@ -268,7 +270,7 @@ data class TransformsSchema(val types: Map>>().apply { schema.types.forEach { type -> getAndAdd(type.name, sf, this) } }) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeNotationGenerator.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeNotationGenerator.kt new file mode 100644 index 0000000000..802c1df68e --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeNotationGenerator.kt @@ -0,0 +1,73 @@ +package net.corda.serialization.internal.amqp + +import net.corda.serialization.internal.model.LocalPropertyInformation +import net.corda.serialization.internal.model.LocalTypeInformation +import net.corda.serialization.internal.model.TypeIdentifier +import org.apache.qpid.proton.amqp.Symbol +import java.io.NotSerializableException + +object TypeNotationGenerator { + + fun getTypeNotation(typeInformation: LocalTypeInformation, typeDescriptor: Symbol) = when(typeInformation) { + is LocalTypeInformation.AnInterface -> typeInformation.getTypeNotation(typeDescriptor) + is LocalTypeInformation.Composable -> typeInformation.getTypeNotation(typeDescriptor) + is LocalTypeInformation.Abstract -> typeInformation.getTypeNotation(typeDescriptor) + else -> throw NotSerializableException("Cannot generate type notation for $typeInformation") + } + + private val LocalTypeInformation.amqpTypeName get() = AMQPTypeIdentifiers.nameForType(typeIdentifier) + + private fun LocalTypeInformation.AnInterface.getTypeNotation(typeDescriptor: Symbol): CompositeType = + makeCompositeType( + (sequenceOf(this) + interfaces.asSequence()).toList(), + properties, + typeDescriptor) + + private fun LocalTypeInformation.Composable.getTypeNotation(typeDescriptor: Symbol): CompositeType = + makeCompositeType(interfaces, properties, typeDescriptor) + + private fun LocalTypeInformation.Abstract.getTypeNotation(typeDescriptor: Symbol): CompositeType = + makeCompositeType(interfaces, properties, typeDescriptor) + + private fun LocalTypeInformation.makeCompositeType( + interfaces: List, + properties: Map, + typeDescriptor: Symbol): CompositeType { + val provides = interfaces.map { it.amqpTypeName } + val fields = properties.map { (name, property) -> + property.getField(name) + } + + return CompositeType( + amqpTypeName, + null, + provides, + Descriptor(typeDescriptor), + fields) + } + + private fun LocalPropertyInformation.getField(name: String): Field { + val (typeName, requires) = when(type) { + is LocalTypeInformation.AnInterface, + is LocalTypeInformation.ACollection, + is LocalTypeInformation.AMap -> "*" to listOf(type.amqpTypeName) + else -> type.amqpTypeName to emptyList() + } + + val defaultValue: String? = defaultValues[type.typeIdentifier] + + return Field(name, typeName, requires, defaultValue, null, isMandatory, false) + } + + private val defaultValues = sequenceOf( + Boolean::class to "false", + Byte::class to "0", + Int::class to "0", + Char::class to "�", + Short::class to "0", + Long::class to "0", + Float::class to "0", + Double::class to "0").associate { (type, value) -> + TypeIdentifier.forClass(type.javaPrimitiveType!!) to value + } +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeParameterUtils.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeParameterUtils.kt index 72720add79..f16f58b0ae 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeParameterUtils.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/TypeParameterUtils.kt @@ -7,7 +7,6 @@ import java.lang.reflect.* * Try and infer concrete types for any generics type variables for the actual class encountered, * based on the declared type. */ -// TODO: test GenericArrayType fun inferTypeVariables(actualClass: Class<*>, declaredClass: Class<*>, declaredType: Type): Type? = when (declaredType) { @@ -17,10 +16,7 @@ fun inferTypeVariables(actualClass: Class<*>, inferTypeVariables(actualClass.componentType, declaredComponent.asClass(), declaredComponent)?.asArray() } // Nothing to infer, otherwise we'd have ParameterizedType - is Class<*> -> actualClass - is TypeVariable<*> -> actualClass - is WildcardType -> actualClass - else -> throw UnsupportedOperationException("Cannot infer type variables for type $declaredType") + else -> actualClass } /** @@ -32,12 +28,6 @@ private fun inferTypeVariables(actualClass: Class<*>, declaredClass: Class<*>, d return null } - if (!declaredClass.isAssignableFrom(actualClass)) { - throw AMQPNotSerializableException( - declaredType, - "Found object of type $actualClass in a property expecting $declaredType") - } - if (actualClass.typeParameters.isEmpty()) { return actualClass } @@ -55,7 +45,7 @@ private fun inferTypeVariables(actualClass: Class<*>, declaredClass: Class<*>, d TypeResolver().where(chainEntry, newResolved) } // The end type is a special case as it is a Class, so we need to fake up a ParameterizedType for it to get the TypeResolver to do anything. - val endType = DeserializedParameterizedType(actualClass, actualClass.typeParameters) + val endType = actualClass.asParameterizedType() return resolver.resolveType(endType) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ClassSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ClassSerializer.kt index dd4cbd9f9f..8d60f06e0c 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ClassSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ClassSerializer.kt @@ -5,6 +5,7 @@ import net.corda.core.utilities.contextLogger import net.corda.core.utilities.trace import net.corda.serialization.internal.amqp.AMQPNotSerializableException import net.corda.serialization.internal.amqp.CustomSerializer +import net.corda.serialization.internal.amqp.LocalSerializerFactory import net.corda.serialization.internal.amqp.SerializerFactory import net.corda.serialization.internal.amqp.custom.ClassSerializer.ClassProxy @@ -12,7 +13,7 @@ import net.corda.serialization.internal.amqp.custom.ClassSerializer.ClassProxy * A serializer for [Class] that uses [ClassProxy] proxy object to write out */ class ClassSerializer( - factory: SerializerFactory + factory: LocalSerializerFactory ) : CustomSerializer.Proxy, ClassSerializer.ClassProxy>( Class::class.java, ClassProxy::class.java, diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/InputStreamSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/InputStreamSerializer.kt index 1a7f5fce89..46d2ae80e3 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/InputStreamSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/InputStreamSerializer.kt @@ -20,7 +20,7 @@ object InputStreamSerializer : CustomSerializer.Implements(InputStr type.toString(), "", listOf(type.toString()), - SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, + AMQPTypeIdentifiers.primitiveTypeName(ByteArray::class.java), descriptor, emptyList()))) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt index 7bf9bbf344..118ca2312d 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PrivateKeySerializer.kt @@ -8,11 +8,10 @@ import net.corda.serialization.internal.checkUseCase import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type import java.security.PrivateKey -import java.util.* object PrivateKeySerializer : CustomSerializer.Implements(PrivateKey::class.java) { - override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList()))) + override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), AMQPTypeIdentifiers.primitiveTypeName(ByteArray::class.java), descriptor, emptyList()))) override fun writeDescribedObject(obj: PrivateKey, data: Data, type: Type, output: SerializationOutput, context: SerializationContext diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt index bf6025360d..ef1be88fce 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/PublicKeySerializer.kt @@ -11,7 +11,7 @@ import java.security.PublicKey * A serializer that writes out a public key in X.509 format. */ object PublicKeySerializer : CustomSerializer.Implements(PublicKey::class.java) { - override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList()))) + override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), AMQPTypeIdentifiers.primitiveTypeName(ByteArray::class.java), descriptor, emptyList()))) override fun writeDescribedObject(obj: PublicKey, data: Data, type: Type, output: SerializationOutput, context: SerializationContext diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ThrowableSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ThrowableSerializer.kt index 3b4b03800e..b3082a629e 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ThrowableSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/ThrowableSerializer.kt @@ -6,10 +6,13 @@ import net.corda.core.KeepForDJVM import net.corda.core.serialization.SerializationFactory import net.corda.core.utilities.contextLogger import net.corda.serialization.internal.amqp.* +import net.corda.serialization.internal.model.LocalConstructorInformation +import net.corda.serialization.internal.model.LocalPropertyInformation +import net.corda.serialization.internal.model.LocalTypeInformation import java.io.NotSerializableException @KeepForDJVM -class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy(Throwable::class.java, ThrowableProxy::class.java, factory) { +class ThrowableSerializer(factory: LocalSerializerFactory) : CustomSerializer.Proxy(Throwable::class.java, ThrowableProxy::class.java, factory) { companion object { private val logger = contextLogger() @@ -19,15 +22,23 @@ class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy> = listOf(StackTraceElementSerializer(factory)) + private val LocalTypeInformation.constructor: LocalConstructorInformation get() = when(this) { + is LocalTypeInformation.NonComposable -> constructor ?: + throw NotSerializableException("$this has no deserialization constructor") + is LocalTypeInformation.Composable -> constructor + is LocalTypeInformation.Opaque -> expand.constructor + else -> throw NotSerializableException("$this has no deserialization constructor") + } + override fun toProxy(obj: Throwable): ThrowableProxy { val extraProperties: MutableMap = LinkedHashMap() val message = if (obj is CordaThrowable) { // Try and find a constructor try { - val constructor = constructorForDeserialization(obj.javaClass) - propertiesForSerializationFromConstructor(constructor, obj.javaClass, factory).forEach { property -> - extraProperties[property.serializer.name] = property.serializer.propertyReader.read(obj) - } + val typeInformation = factory.getTypeInformation(obj.javaClass) + extraProperties.putAll(typeInformation.propertiesOrEmptyMap.mapValues { (_, property) -> + PropertyReader.make(property).read(obj) + }) } catch (e: NotSerializableException) { logger.warn("Unexpected exception", e) } @@ -52,8 +63,13 @@ class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy + proxy.additionalProperties[parameter.name] ?: + proxy.additionalProperties[parameter.name.capitalize()] + } + val throwable = constructor.observedMethod.call(*params.toTypedArray()) (throwable as CordaThrowable).apply { if (this.javaClass.name != proxy.exceptionClass) this.originalExceptionClassName = proxy.exceptionClass this.setMessage(proxy.message) @@ -85,7 +101,7 @@ class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy) } -class StackTraceElementSerializer(factory: SerializerFactory) : CustomSerializer.Proxy(StackTraceElement::class.java, StackTraceElementProxy::class.java, factory) { +class StackTraceElementSerializer(factory: LocalSerializerFactory) : CustomSerializer.Proxy(StackTraceElement::class.java, StackTraceElementProxy::class.java, factory) { override fun toProxy(obj: StackTraceElement): StackTraceElementProxy = StackTraceElementProxy(obj.className, obj.methodName, obj.fileName, obj.lineNumber) override fun fromProxy(proxy: StackTraceElementProxy): StackTraceElement = StackTraceElement(proxy.declaringClass, proxy.methodName, proxy.fileName, proxy.lineNumber) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt index 9a5a06b62e..e7374dd114 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CRLSerializer.kt @@ -12,7 +12,7 @@ object X509CRLSerializer : CustomSerializer.Implements(X509CRL::class.j type.toString(), "", listOf(type.toString()), - SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, + AMQPTypeIdentifiers.primitiveTypeName(ByteArray::class.java), descriptor, emptyList() ))) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt index 5d00cef9b0..90063aaa99 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/custom/X509CertificateSerializer.kt @@ -12,7 +12,7 @@ object X509CertificateSerializer : CustomSerializer.Implements( type.toString(), "", listOf(type.toString()), - SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, + AMQPTypeIdentifiers.primitiveTypeName(ByteArray::class.java), descriptor, emptyList() ))) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/AMQPSchemaExtensions.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/AMQPSchemaExtensions.kt deleted file mode 100644 index 37736abcde..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/AMQPSchemaExtensions.kt +++ /dev/null @@ -1,154 +0,0 @@ -@file:JvmName("AMQPSchemaExtensions") - -package net.corda.serialization.internal.carpenter - -import net.corda.core.DeleteForDJVM -import net.corda.core.serialization.SerializationContext -import net.corda.serialization.internal.amqp.CompositeType -import net.corda.serialization.internal.amqp.RestrictedType -import net.corda.serialization.internal.amqp.Field as AMQPField -import net.corda.serialization.internal.amqp.Schema as AMQPSchema - -@DeleteForDJVM -fun AMQPSchema.carpenterSchema(classloader: ClassLoader): CarpenterMetaSchema { - val rtn = CarpenterMetaSchema.newInstance() - - types.filterIsInstance().forEach { - it.carpenterSchema(classloader, carpenterSchemas = rtn) - } - - return rtn -} - -/** - * if we can load the class then we MUST know about all of it's composite elements - */ -private fun CompositeType.validatePropertyTypes(classloader: ClassLoader) { - fields.forEach { - if (!it.validateType(classloader)) throw UncarpentableException(name, it.name, it.type) - } -} - -fun AMQPField.typeAsString() = if (type == "*") requires[0] else type - -/** - * based upon this AMQP schema either - * a) add the corresponding carpenter schema to the [carpenterSchemas] param - * b) add the class to the dependency tree in [carpenterSchemas] if it cannot be instantiated - * at this time - * - * @param classloader the class loader provided by the [SerializationContext] - * @param carpenterSchemas structure that holds the dependency tree and list of classes that - * need constructing - * @param force by default a schema is not added to [carpenterSchemas] if it already exists - * on the class path. For testing purposes schema generation can be forced - */ -@DeleteForDJVM -fun CompositeType.carpenterSchema(classloader: ClassLoader, - carpenterSchemas: CarpenterMetaSchema, - force: Boolean = false) { - if (classloader.exists(name)) { - validatePropertyTypes(classloader) - if (!force) return - } - - val providesList = mutableListOf>() - var isInterface = false - var isCreatable = true - - provides.forEach { - if (name == it) { - isInterface = true - return@forEach - } - - try { - providesList.add(classloader.loadClass(it.stripGenerics())) - } catch (e: ClassNotFoundException) { - carpenterSchemas.addDepPair(this, name, it) - isCreatable = false - } - } - - val m: MutableMap = mutableMapOf() - - fields.forEach { - try { - m[it.name] = FieldFactory.newInstance(it.mandatory, it.name, it.getTypeAsClass(classloader)) - } catch (e: ClassNotFoundException) { - carpenterSchemas.addDepPair(this, name, it.typeAsString()) - isCreatable = false - } - } - - if (isCreatable) { - carpenterSchemas.carpenterSchemas.add(CarpenterSchemaFactory.newInstance( - name = name, - fields = m, - interfaces = providesList, - isInterface = isInterface)) - } -} - -// This is potentially problematic as we're assuming the only type of restriction we will be -// carpenting for, an enum, but actually trying to split out RestrictedType into something -// more polymorphic is hard. Additionally, to conform to AMQP we're really serialising -// this as a list so... -@DeleteForDJVM -fun RestrictedType.carpenterSchema(carpenterSchemas: CarpenterMetaSchema) { - val m: MutableMap = mutableMapOf() - - choices.forEach { m[it.name] = EnumField() } - - carpenterSchemas.carpenterSchemas.add(EnumSchema(name = name, fields = m)) -} - -// map a pair of (typename, mandatory) to the corresponding class type -// where the mandatory AMQP flag maps to the types nullability -val typeStrToType: Map, Class> = mapOf( - Pair("int", true) to Int::class.javaPrimitiveType!!, - Pair("int", false) to Integer::class.javaObjectType, - Pair("short", true) to Short::class.javaPrimitiveType!!, - Pair("short", false) to Short::class.javaObjectType, - Pair("long", true) to Long::class.javaPrimitiveType!!, - Pair("long", false) to Long::class.javaObjectType, - Pair("char", true) to Char::class.javaPrimitiveType!!, - Pair("char", false) to java.lang.Character::class.java, - Pair("boolean", true) to Boolean::class.javaPrimitiveType!!, - Pair("boolean", false) to Boolean::class.javaObjectType, - Pair("double", true) to Double::class.javaPrimitiveType!!, - Pair("double", false) to Double::class.javaObjectType, - Pair("float", true) to Float::class.javaPrimitiveType!!, - Pair("float", false) to Float::class.javaObjectType, - Pair("byte", true) to Byte::class.javaPrimitiveType!!, - Pair("byte", false) to Byte::class.javaObjectType -) - -fun String.stripGenerics(): String = if (this.endsWith('>')) { - this.substring(0, this.indexOf('<')) -} else this - -fun AMQPField.getTypeAsClass(classloader: ClassLoader) = (typeStrToType[Pair(type, mandatory)] ?: when (type) { - "string" -> String::class.java - "binary" -> ByteArray::class.java - "*" -> if (requires.isEmpty()) Any::class.java else { - classloader.loadClass(requires[0].stripGenerics()) - } - else -> classloader.loadClass(type.stripGenerics()) -})!! - -fun AMQPField.validateType(classloader: ClassLoader) = - when (type) { - "byte", "int", "string", "short", "long", "char", "boolean", "double", "float" -> true - "*" -> classloader.exists(requires[0]) - else -> classloader.exists(type) - } - -private fun ClassLoader.exists(clazz: String) = run { - try { - this.loadClass(clazz); true - } catch (e: ClassNotFoundException) { - false - } -} - diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Exceptions.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Exceptions.kt index ac1346b053..ffd6096f04 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Exceptions.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Exceptions.kt @@ -23,14 +23,3 @@ class NullablePrimitiveException(val name: String, val field: Class) : class UncarpentableException(name: String, field: String, type: String) : ClassCarpenterException("Class $name is loadable yet contains field $field of unknown type $type") - -/** - * A meta exception used by the [MetaCarpenter] to wrap any exceptions generated during the build - * process and associate those with the current schema being processed. This makes for cleaner external - * error hand - * - * @property name The name of the schema, and thus the class being created, when the error was occured - * @property e The [ClassCarpenterException] this is wrapping - */ -class MetaCarpenterException(val name: String, val e: ClassCarpenterException) : CordaRuntimeException( - "Whilst processing class '$name' - ${e.message}") diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/MetaCarpenter.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/MetaCarpenter.kt deleted file mode 100644 index 445c3ce7da..0000000000 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/MetaCarpenter.kt +++ /dev/null @@ -1,127 +0,0 @@ -package net.corda.serialization.internal.carpenter - -import net.corda.core.DeleteForDJVM -import net.corda.core.KeepForDJVM -import net.corda.core.StubOutForDJVM -import net.corda.serialization.internal.amqp.CompositeType -import net.corda.serialization.internal.amqp.RestrictedType -import net.corda.serialization.internal.amqp.TypeNotation - -/** - * Generated from an AMQP schema this class represents the classes unknown to the deserializer and that thusly - * require carpenting up in bytecode form. This is a multi step process as carpenting one object may be dependent - * upon the creation of others, this information is tracked in the dependency tree represented by - * [dependencies] and [dependsOn]. Creatable classes are stored in [carpenterSchemas]. - * - * The state of this class after initial generation is expected to mutate as classes are built by the carpenter - * enabling the resolution of dependencies and thus new carpenter schemas added whilst those already - * carpented schemas are removed. - * - * @property carpenterSchemas The list of carpentable classes - * @property dependencies Maps a class to a list of classes that depend on it being built first - * @property dependsOn Maps a class to a list of classes it depends on being built before it - * - * Once a class is constructed we can quickly check for resolution by first looking at all of its dependents in the - * [dependencies] map. This will give us a list of classes that depended on that class being carpented. We can then - * in turn look up all of those classes in the [dependsOn] list, remove their dependency on the newly created class, - * and if that list is reduced to zero know we can now generate a [Schema] for them and carpent them up - */ -@KeepForDJVM -data class CarpenterMetaSchema( - val carpenterSchemas: MutableList, - val dependencies: MutableMap>>, - val dependsOn: MutableMap>) { - companion object CarpenterSchemaConstructor { - fun buildWith(classLoader: ClassLoader, types: List) = - newInstance().apply { - types.forEach { buildFor(it, classLoader) } - } - - fun newInstance(): CarpenterMetaSchema { - return CarpenterMetaSchema(mutableListOf(), mutableMapOf(), mutableMapOf()) - } - } - - fun addDepPair(type: TypeNotation, dependant: String, dependee: String) { - dependsOn.computeIfAbsent(dependee, { mutableListOf() }).add(dependant) - dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf()) }).second.add(dependee) - } - - val size - get() = carpenterSchemas.size - - fun isEmpty() = carpenterSchemas.isEmpty() - fun isNotEmpty() = carpenterSchemas.isNotEmpty() - - // We could make this an abstract method on TypeNotation but that - // would mean the amqp package being "more" infected with carpenter - // specific bits. - @StubOutForDJVM - fun buildFor(target: TypeNotation, cl: ClassLoader): Unit = when (target) { - is RestrictedType -> target.carpenterSchema(this) - is CompositeType -> target.carpenterSchema(cl, this, false) - } -} - -/** - * Take a dependency tree of [CarpenterMetaSchema] and reduce it to zero by carpenting those classes that - * require it. As classes are carpented check for dependency resolution, if now free generate a [Schema] for - * that class and add it to the list of classes ([CarpenterMetaSchema.carpenterSchemas]) that require - * carpenting - * - * @property cc a reference to the actual class carpenter we're using to constuct classes - * @property objects a list of carpented classes loaded into the carpenters class loader - */ -@DeleteForDJVM -abstract class MetaCarpenterBase(val schemas: CarpenterMetaSchema, val cc: ClassCarpenter) { - val objects = mutableMapOf>() - - fun step(newObject: Schema) { - objects[newObject.name] = cc.build(newObject) - - // go over the list of everything that had a dependency on the newly - // carpented class existing and remove it from their dependency list, If that - // list is now empty we have no impediment to carpenting that class up - schemas.dependsOn.remove(newObject.name)?.forEach { dependent -> - - require(newObject.name in schemas.dependencies[dependent]!!.second) - - schemas.dependencies[dependent]?.second?.remove(newObject.name) - - // we're out of blockers so we can now create the type - if (schemas.dependencies[dependent]?.second?.isEmpty() == true) { - (schemas.dependencies.remove(dependent)?.first as CompositeType).carpenterSchema( - classloader = cc.classloader, - carpenterSchemas = schemas) - } - } - } - - abstract fun build() - - val classloader: ClassLoader - get() = cc.classloader -} - -@DeleteForDJVM -class MetaCarpenter(schemas: CarpenterMetaSchema, cc: ClassCarpenter) : MetaCarpenterBase(schemas, cc) { - override fun build() { - while (schemas.carpenterSchemas.isNotEmpty()) { - val newObject = schemas.carpenterSchemas.removeAt(0) - try { - step(newObject) - } catch (e: ClassCarpenterException) { - throw MetaCarpenterException(newObject.name, e) - } - } - } -} - -@DeleteForDJVM -class TestMetaCarpenter(schemas: CarpenterMetaSchema, cc: ClassCarpenter) : MetaCarpenterBase(schemas, cc) { - override fun build() { - if (schemas.carpenterSchemas.isEmpty()) return - step(schemas.carpenterSchemas.removeAt(0)) - } -} - diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Schema.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Schema.kt index 8690668367..90a1034f70 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Schema.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/carpenter/Schema.kt @@ -74,7 +74,7 @@ fun EnumMap.simpleFieldAccess(): Boolean { class ClassSchema( name: String, fields: Map, - superclass: Schema? = null, + superclass: Schema? = null, // always null for now, but retained because non-null superclass is supported by carpenter. interfaces: List> = emptyList() ) : Schema(name, fields, superclass, interfaces, { newName, field -> field.name = newName }) { override fun generateFields(cw: ClassWriter) { @@ -128,11 +128,10 @@ object CarpenterSchemaFactory { fun newInstance( name: String, fields: Map, - superclass: Schema? = null, interfaces: List> = emptyList(), isInterface: Boolean = false ): Schema = - if (isInterface) InterfaceSchema(name, fields, superclass, interfaces) - else ClassSchema(name, fields, superclass, interfaces) + if (isInterface) InterfaceSchema(name, fields, null, interfaces) + else ClassSchema(name, fields, null, interfaces) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/CarpentryDependencyGraph.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/CarpentryDependencyGraph.kt index 6b6edfac88..d9959672e1 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/CarpentryDependencyGraph.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/CarpentryDependencyGraph.kt @@ -6,8 +6,7 @@ import java.lang.reflect.Type /** * Once we have the complete graph of types requiring carpentry to hand, we can use it to sort those types in reverse- * dependency order, i.e. beginning with those types that have no dependencies on other types, then the types that - * depended on those types, and so on. This means we can feed types directly to the [RemoteTypeCarpenter], and don't - * have to use the [CarpenterMetaSchema]. + * depended on those types, and so on. This means we can feed types in this order directly to the [RemoteTypeCarpenter]. * * @param typesRequiringCarpentry The set of [RemoteTypeInformation] for types that are not reachable by the current * classloader. diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformation.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformation.kt index 206417138b..c7473623d4 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformation.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformation.kt @@ -56,8 +56,19 @@ sealed class LocalTypeInformation { * @param type The [Type] to obtain [LocalTypeInformation] for. * @param lookup The [LocalTypeLookup] to use to find previously-constructed [LocalTypeInformation]. */ - fun forType(type: Type, lookup: LocalTypeLookup): LocalTypeInformation = - LocalTypeInformationBuilder(lookup).build(type, TypeIdentifier.forGenericType(type)) + fun forType(type: Type, lookup: LocalTypeLookup): LocalTypeInformation { + val builder = LocalTypeInformationBuilder(lookup) + val result = builder.build(type, TypeIdentifier.forGenericType(type)) + + // Patch every cyclic reference with a `follow` property pointing to the type information it refers to. + builder.cycles.forEach { cycle -> + cycle.follow = lookup.findOrBuild(cycle.observedType, cycle.typeIdentifier) { + throw IllegalStateException("Should not be attempting to build new type information when populating a cycle") + } + } + + return result + } } /** @@ -71,6 +82,29 @@ sealed class LocalTypeInformation { */ abstract val typeIdentifier: TypeIdentifier + /** + * Get the map of [LocalPropertyInformation], for all types that have it, or an empty map otherwise. + */ + val propertiesOrEmptyMap: Map get() = when(this) { + is LocalTypeInformation.Composable -> properties + is LocalTypeInformation.Abstract -> properties + is LocalTypeInformation.AnInterface -> properties + is LocalTypeInformation.NonComposable -> properties + is LocalTypeInformation.Opaque -> expand.propertiesOrEmptyMap + else -> emptyMap() + } + + /** + * Get the list of interfaces, for all types that have them, or an empty list otherwise. + */ + val interfacesOrEmptyList: List get() = when(this) { + is LocalTypeInformation.Composable -> interfaces + is LocalTypeInformation.Abstract -> interfaces + is LocalTypeInformation.AnInterface -> interfaces + is LocalTypeInformation.NonComposable -> interfaces + else -> emptyList() + } + /** * Obtain a multi-line, recursively-indented representation of this type information. * @@ -101,11 +135,10 @@ sealed class LocalTypeInformation { */ data class Cycle( override val observedType: Type, - override val typeIdentifier: TypeIdentifier, - private val _follow: () -> LocalTypeInformation) : LocalTypeInformation() { - val follow: LocalTypeInformation get() = _follow() + override val typeIdentifier: TypeIdentifier) : LocalTypeInformation() { + lateinit var follow: LocalTypeInformation - // Custom equals / hashcode because otherwise the "follow" lambda makes equality harder to reason about. + // Custom equals / hashcode omitting "follow" override fun equals(other: Any?): Boolean = other is Cycle && other.observedType == observedType && @@ -121,7 +154,10 @@ sealed class LocalTypeInformation { */ data class Opaque(override val observedType: Class<*>, override val typeIdentifier: TypeIdentifier, private val _expand: () -> LocalTypeInformation) : LocalTypeInformation() { - val expand: LocalTypeInformation get() = _expand() + /** + * In some rare cases, e.g. during Exception serialisation, we may want to "look inside" an opaque type. + */ + val expand: LocalTypeInformation by lazy { _expand() } // Custom equals / hashcode because otherwise the "expand" lambda makes equality harder to reason about. override fun equals(other: Any?): Boolean = @@ -202,6 +238,7 @@ sealed class LocalTypeInformation { * * @param constructor [LocalConstructorInformation] for the constructor used when building instances of this type * out of dictionaries of typed values. + * @param evolutionConstructors Evolution constructors in ascending version order. * @param properties [LocalPropertyInformation] for the properties of the interface. * @param superclass [LocalTypeInformation] for the superclass of the underlying class of this type. * @param interfaces [LocalTypeInformation] for the interfaces extended by this interface. @@ -211,7 +248,7 @@ sealed class LocalTypeInformation { override val observedType: Type, override val typeIdentifier: TypeIdentifier, val constructor: LocalConstructorInformation, - val evolverConstructors: List, + val evolutionConstructors: List, val properties: Map, val superclass: LocalTypeInformation, val interfaces: List, @@ -312,7 +349,7 @@ data class LocalConstructorInformation( * Represents information about a constructor that is specifically to be used for evolution, and is potentially matched * with a different set of properties to the regular constructor. */ -data class EvolverConstructorInformation( +data class EvolutionConstructorInformation( val constructor: LocalConstructorInformation, val properties: Map) @@ -330,16 +367,16 @@ private data class LocalTypeInformationPrettyPrinter(private val simplifyClassNa with(typeInformation) { when (this) { is LocalTypeInformation.Abstract -> - typeIdentifier.prettyPrint() + + typeIdentifier.prettyPrint(simplifyClassNames) + printInheritsFrom(interfaces, superclass) + indentAnd { printProperties(properties) } is LocalTypeInformation.AnInterface -> - typeIdentifier.prettyPrint() + printInheritsFrom(interfaces) - is LocalTypeInformation.Composable -> typeIdentifier.prettyPrint() + + typeIdentifier.prettyPrint(simplifyClassNames) + printInheritsFrom(interfaces) + is LocalTypeInformation.Composable -> typeIdentifier.prettyPrint(simplifyClassNames) + printConstructor(constructor) + printInheritsFrom(interfaces, superclass) + indentAnd { printProperties(properties) } - else -> typeIdentifier.prettyPrint() + else -> typeIdentifier.prettyPrint(simplifyClassNames) } } @@ -366,7 +403,7 @@ private data class LocalTypeInformationPrettyPrinter(private val simplifyClassNa " ".repeat(indent) + key + (if(!value.isMandatory) " (optional)" else "") + (if (value.isCalculated) " (calculated)" else "") + - ": " + value.type.prettyPrint(simplifyClassNames) + ": " + prettyPrint(value.type) private inline fun indentAnd(block: LocalTypeInformationPrettyPrinter.() -> String) = copy(indent = indent + 1).block() diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt index 88df121a43..5e596d6465 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/LocalTypeInformationBuilder.kt @@ -32,7 +32,10 @@ import kotlin.reflect.jvm.javaType * this is not a [MutableSet], as we want to be able to backtrack while traversing through the graph of related types, and * will find it useful to revert to earlier states of knowledge about which types have been visited on a given branch. */ -internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, val resolutionContext: Type? = null, val visited: Set = emptySet()) { +internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, + val resolutionContext: Type? = null, + val visited: Set = emptySet(), + val cycles: MutableList = mutableListOf()) { companion object { private val logger = contextLogger() @@ -42,9 +45,7 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, val * Recursively build [LocalTypeInformation] for the given [Type] and [TypeIdentifier] */ fun build(type: Type, typeIdentifier: TypeIdentifier): LocalTypeInformation = - if (typeIdentifier in visited) LocalTypeInformation.Cycle(type, typeIdentifier) { - LocalTypeInformationBuilder(lookup, resolutionContext).build(type, typeIdentifier) - } + if (typeIdentifier in visited) LocalTypeInformation.Cycle(type, typeIdentifier).apply { cycles.add(this) } else lookup.findOrBuild(type, typeIdentifier) { isOpaque -> copy(visited = visited + typeIdentifier).buildIfNotFound(type, typeIdentifier, isOpaque) } @@ -184,13 +185,13 @@ internal data class LocalTypeInformationBuilder(val lookup: LocalTypeLookup, val interfaceInformation, typeParameterInformation) } - val evolverConstructors = evolverConstructors(type).map { ctor -> + val evolutionConstructors = evolutionConstructors(type).map { ctor -> val constructorInformation = buildConstructorInformation(type, ctor) - val evolverProperties = buildObjectProperties(rawType, constructorInformation) - EvolverConstructorInformation(constructorInformation, evolverProperties) + val evolutionProperties = buildObjectProperties(rawType, constructorInformation) + EvolutionConstructorInformation(constructorInformation, evolutionProperties) } - return LocalTypeInformation.Composable(type, typeIdentifier, constructorInformation, evolverConstructors, properties, + return LocalTypeInformation.Composable(type, typeIdentifier, constructorInformation, evolutionConstructors, properties, superclassInformation, interfaceInformation, typeParameterInformation) } @@ -395,7 +396,10 @@ private fun constructorForDeserialization(type: Type): KFunction? { } } -private fun evolverConstructors(type: Type): List> { +/** + * Obtain evolution constructors in ascending version order. + */ +private fun evolutionConstructors(type: Type): List> { val clazz = type.asClass() if (!clazz.isConcreteClass || clazz.isSynthetic) return emptyList() diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeCarpenter.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeCarpenter.kt index 75e4b9363d..e7b97cf45b 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeCarpenter.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeCarpenter.kt @@ -23,7 +23,10 @@ class SchemaBuildingRemoteTypeCarpenter(private val carpenter: ClassCarpenter): try { when (typeInformation) { is RemoteTypeInformation.AnInterface -> typeInformation.carpentInterface() - is RemoteTypeInformation.Composable -> typeInformation.carpentComposable() + is RemoteTypeInformation.Composable -> + // We cannot carpent parameterised types, and if the type is parameterised assume we are really here + // because a type parameter needed carpenting. + if (typeInformation.typeIdentifier !is TypeIdentifier.Parameterised) typeInformation.carpentComposable() is RemoteTypeInformation.AnEnum -> typeInformation.carpentEnum() else -> { } // Anything else, such as arrays, will be taken care of by the above @@ -31,7 +34,14 @@ class SchemaBuildingRemoteTypeCarpenter(private val carpenter: ClassCarpenter): } catch (e: ClassCarpenterException) { throw NotSerializableException("${typeInformation.typeIdentifier.name}: ${e.message}") } - return typeInformation.typeIdentifier.getLocalType(classLoader) + + return try { + typeInformation.typeIdentifier.getLocalType(classLoader) + } catch (e: ClassNotFoundException) { + // This might happen if we've been asked to carpent up a parameterised type, and it's the rawtype itself + // rather than any of its type parameters that were missing. + throw NotSerializableException("Could not carpent ${typeInformation.typeIdentifier.prettyPrint(false)}") + } } private val RemoteTypeInformation.erasedLocalClass get() = typeIdentifier.getLocalType(classLoader).asClass() diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeInformation.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeInformation.kt index 93c0e72307..7cdae243a8 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeInformation.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/RemoteTypeInformation.kt @@ -1,9 +1,5 @@ package net.corda.serialization.internal.model -import net.corda.serialization.internal.amqp.Transform -import net.corda.serialization.internal.amqp.TransformTypes -import java.util.* - typealias TypeDescriptor = String /** @@ -88,9 +84,9 @@ sealed class RemoteTypeInformation { /** * The [RemoteTypeInformation] emitted if we hit a cycle while traversing the graph of related types. */ - data class Cycle(override val typeIdentifier: TypeIdentifier, private val _follow: () -> RemoteTypeInformation) : RemoteTypeInformation() { - override val typeDescriptor = typeIdentifier.name - val follow: RemoteTypeInformation get() = _follow() + data class Cycle(override val typeIdentifier: TypeIdentifier) : RemoteTypeInformation() { + override val typeDescriptor by lazy { follow.typeDescriptor } + lateinit var follow: RemoteTypeInformation override fun equals(other: Any?): Boolean = other is Cycle && other.typeIdentifier == typeIdentifier override fun hashCode(): Int = typeIdentifier.hashCode() @@ -176,14 +172,14 @@ private data class RemoteTypeInformationPrettyPrinter(private val simplifyClassN } private fun printProperties(properties: Map) = - properties.entries.sortedBy { it.key }.joinToString("\n", "\n", "") { + properties.entries.joinToString("\n", "\n", "") { it.prettyPrint() } private fun Map.Entry.prettyPrint(): String = " ".repeat(indent) + key + (if(!value.isMandatory) " (optional)" else "") + - ": " + value.type.prettyPrint(simplifyClassNames) + ": " + prettyPrint(value.type) } data class EnumTransforms(val defaults: Map, val renames: Map) { diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeModellingFingerPrinter.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeModellingFingerPrinter.kt new file mode 100644 index 0000000000..8afebf0c15 --- /dev/null +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeModellingFingerPrinter.kt @@ -0,0 +1,233 @@ +package net.corda.serialization.internal.model + +import com.google.common.hash.Hashing +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.toBase64 +import net.corda.serialization.internal.amqp.* +import java.io.NotSerializableException + +/** + * A fingerprinter that fingerprints [LocalTypeInformation]. + */ +interface FingerPrinter { + /** + * Traverse the provided [LocalTypeInformation] graph and emit a short fingerprint string uniquely representing + * the shape of that graph. + * + * @param typeInformation The [LocalTypeInformation] to fingerprint. + */ + fun fingerprint(typeInformation: LocalTypeInformation): String +} + +/** + * A [FingerPrinter] that consults a [CustomTypeDescriptorLookup] to obtain type descriptors for + * types that do not need to be traversed to calculate their fingerprint information. (Usually these will be the type + * descriptors supplied by custom serializers). + * + * @param customTypeDescriptorLookup The [CustomTypeDescriptorLookup] to use to obtain custom type descriptors for + * selected types. + */ +class TypeModellingFingerPrinter( + private val customTypeDescriptorLookup: CustomSerializerRegistry, + private val debugEnabled: Boolean = false) : FingerPrinter { + + private val cache: MutableMap = DefaultCacheProvider.createCache() + + override fun fingerprint(typeInformation: LocalTypeInformation): String = + cache.computeIfAbsent(typeInformation.typeIdentifier) { + FingerPrintingState( + customTypeDescriptorLookup, + FingerprintWriter(debugEnabled)).fingerprint(typeInformation) + } +} + +/** + * Wrapper for the [Hasher] we use to generate fingerprints, providing methods for writing various kinds of content + * into the hash. + */ +internal class FingerprintWriter(debugEnabled: Boolean) { + + companion object { + private const val ARRAY_HASH: String = "Array = true" + private const val ENUM_HASH: String = "Enum = true" + private const val ALREADY_SEEN_HASH: String = "Already seen = true" + private const val NULLABLE_HASH: String = "Nullable = true" + private const val NOT_NULLABLE_HASH: String = "Nullable = false" + private const val ANY_TYPE_HASH: String = "Any type = true" + + private val logger = contextLogger() + } + + private val debugBuffer: StringBuilder? = if (debugEnabled) StringBuilder() else null + private var hasher = Hashing.murmur3_128().newHasher() // FIXUP: remove dependency on Guava Hasher + + fun write(chars: CharSequence) = append(chars) + fun write(words: List) = append(words.joinToString()) + fun writeAlreadySeen() = append(ALREADY_SEEN_HASH) + fun writeEnum() = append(ENUM_HASH) + fun writeArray() = append(ARRAY_HASH) + fun writeNullable() = append(NULLABLE_HASH) + fun writeNotNullable() = append(NOT_NULLABLE_HASH) + fun writeAny() = append(ANY_TYPE_HASH) + + private fun append(chars: CharSequence) = apply { + debugBuffer?.append(chars) + hasher = hasher.putUnencodedChars(chars) + } + + val fingerprint: String by lazy { + val fingerprint = hasher.hash().asBytes().toBase64() + if (debugBuffer != null) logger.info("$fingerprint from $debugBuffer") + fingerprint + } +} + +/** + * Representation of the current state of fingerprinting, which keeps track of which types have already been visited + * during fingerprinting. + */ +private class FingerPrintingState( + private val customSerializerRegistry: CustomSerializerRegistry, + private val writer: FingerprintWriter) { + + companion object { + private var CHARACTER_TYPE = LocalTypeInformation.Atomic( + Character::class.java, + TypeIdentifier.forClass(Character::class.java)) + } + + private val typesSeen: MutableSet = mutableSetOf() + + /** + * Fingerprint the type recursively, and return the encoded fingerprint written into the hasher. + */ + fun fingerprint(type: LocalTypeInformation): String = + fingerprintType(type).writer.fingerprint + + // This method concatenates various elements of the types recursively as unencoded strings into the hasher, + // effectively creating a unique string for a type which we then hash in the calling function above. + private fun fingerprintType(type: LocalTypeInformation): FingerPrintingState = apply { + // Don't go round in circles. + when { + hasSeen(type.typeIdentifier) -> writer.writeAlreadySeen() + type is LocalTypeInformation.Cycle -> fingerprintType(type.follow) + else -> ifThrowsAppend({ type.observedType.typeName }, { + typesSeen.add(type.typeIdentifier) + fingerprintNewType(type) + }) + } + } + + // For a type we haven't seen before, determine the correct path depending on the type of type it is. + private fun fingerprintNewType(type: LocalTypeInformation) = apply { + when (type) { + is LocalTypeInformation.Cycle -> + throw IllegalStateException("Cyclic references must be dereferenced before fingerprinting") + is LocalTypeInformation.Unknown, + is LocalTypeInformation.Top -> writer.writeAny() + is LocalTypeInformation.AnArray -> { + fingerprintType(type.componentType) + writer.writeArray() + } + is LocalTypeInformation.ACollection -> fingerprintCollection(type) + is LocalTypeInformation.AMap -> fingerprintMap(type) + is LocalTypeInformation.Atomic -> fingerprintName(type) + is LocalTypeInformation.Opaque -> fingerprintOpaque(type) + is LocalTypeInformation.AnEnum -> fingerprintEnum(type) + is LocalTypeInformation.AnInterface -> fingerprintInterface(type) + is LocalTypeInformation.Abstract -> fingerprintAbstract(type) + is LocalTypeInformation.Singleton -> fingerprintName(type) + is LocalTypeInformation.Composable -> fingerprintComposable(type) + is LocalTypeInformation.NonComposable -> throw NotSerializableException( + "Attempted to fingerprint non-composable type ${type.typeIdentifier.prettyPrint(false)}") + } + } + + private fun fingerprintCollection(type: LocalTypeInformation.ACollection) { + fingerprintName(type) + fingerprintType(type.elementType) + } + + private fun fingerprintMap(type: LocalTypeInformation.AMap) { + fingerprintName(type) + fingerprintType(type.keyType) + fingerprintType(type.valueType) + } + + private fun fingerprintOpaque(type: LocalTypeInformation) = + fingerprintWithCustomSerializerOrElse(type) { + fingerprintName(type) + } + + private fun fingerprintInterface(type: LocalTypeInformation.AnInterface) = + fingerprintWithCustomSerializerOrElse(type) { + fingerprintName(type) + writer.writeAlreadySeen() // FIXUP: this replicates the behaviour of the old fingerprinter for compatibility reasons. + fingerprintInterfaces(type.interfaces) + fingerprintTypeParameters(type.typeParameters) + } + + private fun fingerprintAbstract(type: LocalTypeInformation.Abstract) = + fingerprintWithCustomSerializerOrElse(type) { + fingerprintName(type) + fingerprintProperties(type.properties) + fingerprintInterfaces(type.interfaces) + fingerprintTypeParameters(type.typeParameters) + } + + private fun fingerprintComposable(type: LocalTypeInformation.Composable) = + fingerprintWithCustomSerializerOrElse(type) { + fingerprintName(type) + fingerprintProperties(type.properties) + fingerprintInterfaces(type.interfaces) + fingerprintTypeParameters(type.typeParameters) + } + + private fun fingerprintName(type: LocalTypeInformation) { + val identifier = type.typeIdentifier + when (identifier) { + is TypeIdentifier.ArrayOf -> writer.write(identifier.componentType.name).writeArray() + else -> writer.write(identifier.name) + } + } + + private fun fingerprintTypeParameters(typeParameters: List) = + typeParameters.forEach { fingerprintType(it) } + + private fun fingerprintProperties(properties: Map) = + properties.asSequence().sortedBy { it.key }.forEach { (propertyName, propertyType) -> + val (neverMandatory, adjustedType) = adjustType(propertyType.type) + fingerprintType(adjustedType) + writer.write(propertyName) + if (propertyType.isMandatory && !neverMandatory) writer.writeNotNullable() else writer.writeNullable() + } + + // Compensate for the serialisation framework's forcing of char to Character + private fun adjustType(propertyType: LocalTypeInformation): Pair = + if (propertyType.typeIdentifier.name == "char") true to CHARACTER_TYPE else false to propertyType + + private fun fingerprintInterfaces(interfaces: List) = + interfaces.forEach { fingerprintType(it) } + + // ensures any change to the enum (adding constants) will trigger the need for evolution + private fun fingerprintEnum(type: LocalTypeInformation.AnEnum) { + writer.write(type.members).write(type.typeIdentifier.name).writeEnum() + } + + // Give any custom serializers loaded into the factory the chance to supply their own type-descriptors + private fun fingerprintWithCustomSerializerOrElse(type: LocalTypeInformation, defaultAction: () -> Unit) { + val customTypeDescriptor = customSerializerRegistry.findCustomSerializer(type.observedType.asClass(), type.observedType)?.typeDescriptor?.toString() + if (customTypeDescriptor != null) writer.write(customTypeDescriptor) + else defaultAction() + } + + // Test whether we are in a state in which we have already seen the given type. + // + // We don't include Example and Example where type is ? or T in this otherwise we + // generate different fingerprints for class Outer(val a: Inner) when serialising + // and deserializing (assuming deserialization is occurring in a factory that didn't + // serialise the object in the first place (and thus the cache lookup fails). This is also + // true of Any, where we need Example and Example to have the same fingerprint + private fun hasSeen(type: TypeIdentifier) = (type in typesSeen) + && (type != TypeIdentifier.UnknownType) +} diff --git a/serialization/src/test/java/net/corda/serialization/internal/amqp/JavaPrivatePropertyTests.java b/serialization/src/test/java/net/corda/serialization/internal/amqp/JavaPrivatePropertyTests.java index a68e14b572..a811807dcb 100644 --- a/serialization/src/test/java/net/corda/serialization/internal/amqp/JavaPrivatePropertyTests.java +++ b/serialization/src/test/java/net/corda/serialization/internal/amqp/JavaPrivatePropertyTests.java @@ -1,5 +1,6 @@ package net.corda.serialization.internal.amqp; +import net.corda.serialization.internal.amqp.testutils.TestDescriptorBasedSerializerRegistry; import net.corda.serialization.internal.amqp.testutils.TestSerializationContext; import org.junit.Test; @@ -133,8 +134,9 @@ public class JavaPrivatePropertyTests { } @Test - public void singlePrivateWithConstructor() throws NotSerializableException, NoSuchFieldException, IllegalAccessException { - SerializerFactory factory = testDefaultFactory(); + public void singlePrivateWithConstructor() throws NotSerializableException { + TestDescriptorBasedSerializerRegistry registry = new TestDescriptorBasedSerializerRegistry(); + SerializerFactory factory = testDefaultFactory(registry); SerializationOutput ser = new SerializationOutput(factory); DeserializationInput des = new DeserializationInput(factory); @@ -144,22 +146,14 @@ public class JavaPrivatePropertyTests { assertEquals (c.a, c2.a); - // - // Now ensure we actually got a private property serializer - // - Map> serializersByDescriptor = factory.getSerializersByDescriptor(); - - assertEquals(1, serializersByDescriptor.size()); - ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]); - assertEquals(1, cSerializer.getPropertySerializers().getSerializationOrder().size()); - Object[] propertyReaders = cSerializer.getPropertySerializers().getSerializationOrder().toArray(); - assertTrue (((PropertyAccessor)propertyReaders[0]).getSerializer().getPropertyReader() instanceof PrivatePropertyReader); + assertEquals(1, registry.getContents().size()); } @Test public void singlePrivateWithConstructorAndGetter() - throws NotSerializableException, NoSuchFieldException, IllegalAccessException { - SerializerFactory factory = testDefaultFactory(); + throws NotSerializableException { + TestDescriptorBasedSerializerRegistry registry = new TestDescriptorBasedSerializerRegistry(); + SerializerFactory factory = testDefaultFactory(registry); SerializationOutput ser = new SerializationOutput(factory); DeserializationInput des = new DeserializationInput(factory); @@ -169,15 +163,6 @@ public class JavaPrivatePropertyTests { assertEquals (c.a, c2.a); - // - // Now ensure we actually got a private property serializer - // - Map> serializersByDescriptor = factory.getSerializersByDescriptor(); - - assertEquals(1, serializersByDescriptor.size()); - ObjectSerializer cSerializer = ((ObjectSerializer)serializersByDescriptor.values().toArray()[0]); - assertEquals(1, cSerializer.getPropertySerializers().getSerializationOrder().size()); - Object[] propertyReaders = cSerializer.getPropertySerializers().getSerializationOrder().toArray(); - assertTrue (((PropertyAccessor)propertyReaders[0]).getSerializer().getPropertyReader() instanceof PublicPropertyReader); + assertEquals(1, registry.getContents().size()); } } diff --git a/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java b/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java index 1171d7713d..d0e39bb214 100644 --- a/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java +++ b/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java @@ -7,6 +7,8 @@ import net.corda.core.serialization.SerializedBytes; import net.corda.serialization.internal.AllWhitelist; import net.corda.serialization.internal.amqp.*; import net.corda.serialization.internal.amqp.Schema; +import net.corda.serialization.internal.model.RemoteTypeInformation; +import net.corda.serialization.internal.model.TypeIdentifier; import net.corda.testing.core.SerializationEnvironmentRule; import org.junit.Before; import org.junit.Rule; @@ -66,42 +68,23 @@ public class JavaCalculatedValuesToClassCarpenterTest extends AmqpCarpenterBase ObjectAndEnvelope objAndEnv = new DeserializationInput(factory) .deserializeAndReturnEnvelope(serialized, C.class, context); - C amqpObj = objAndEnv.getObj(); - Schema schema = objAndEnv.getEnvelope().getSchema(); - - assertEquals(2, amqpObj.getI()); - assertEquals("4", amqpObj.getSquared()); - assertEquals(2, schema.getTypes().size()); - assertTrue(schema.getTypes().get(0) instanceof CompositeType); - - CompositeType concrete = (CompositeType) schema.getTypes().get(0); - assertEquals(3, concrete.getFields().size()); - assertEquals("doubled", concrete.getFields().get(0).getName()); - assertEquals("int", concrete.getFields().get(0).getType()); - assertEquals("i", concrete.getFields().get(1).getName()); - assertEquals("int", concrete.getFields().get(1).getType()); - assertEquals("squared", concrete.getFields().get(2).getName()); - assertEquals("string", concrete.getFields().get(2).getType()); - - assertEquals(0, AMQPSchemaExtensions.carpenterSchema(schema, ClassLoader.getSystemClassLoader()).getSize()); - Schema mangledSchema = ClassCarpenterTestUtilsKt.mangleNames(schema, singletonList(C.class.getTypeName())); - CarpenterMetaSchema l2 = AMQPSchemaExtensions.carpenterSchema(mangledSchema, ClassLoader.getSystemClassLoader()); - String mangledClassName = ClassCarpenterTestUtilsKt.mangleName(C.class.getTypeName()); - - assertEquals(1, l2.getSize()); - net.corda.serialization.internal.carpenter.Schema carpenterSchema = l2.getCarpenterSchemas().stream() - .filter(s -> s.getName().equals(mangledClassName)) + TypeIdentifier typeToMangle = TypeIdentifier.Companion.forClass(C.class); + Envelope env = objAndEnv.getEnvelope(); + RemoteTypeInformation typeInformation = getTypeInformation(env).values().stream() + .filter(it -> it.getTypeIdentifier().equals(typeToMangle)) .findFirst() - .orElseThrow(() -> new IllegalStateException("No schema found for mangled class name " + mangledClassName)); + .orElseThrow(IllegalStateException::new); - Class pinochio = new ClassCarpenterImpl(AllWhitelist.INSTANCE).build(carpenterSchema); + RemoteTypeInformation renamed = rename(typeInformation, typeToMangle, mangle(typeToMangle)); + + Class pinochio = load(renamed); Object p = pinochio.getConstructors()[0].newInstance(4, 2, "4"); - assertEquals(pinochio.getMethod("getI").invoke(p), amqpObj.getI()); - assertEquals(pinochio.getMethod("getSquared").invoke(p), amqpObj.getSquared()); - assertEquals(pinochio.getMethod("getDoubled").invoke(p), amqpObj.getDoubled()); + assertEquals(2, pinochio.getMethod("getI").invoke(p)); + assertEquals("4", pinochio.getMethod("getSquared").invoke(p)); + assertEquals(4, pinochio.getMethod("getDoubled").invoke(p)); Parent upcast = (Parent) p; - assertEquals(upcast.getDoubled(), amqpObj.getDoubled()); + assertEquals(4, upcast.getDoubled()); } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt index 0164664800..55c0254095 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/ListsSerializationTest.kt @@ -88,7 +88,7 @@ class ListsSerializationTest { payload.add(2) val wrongPayloadType = WrongPayloadType(payload) Assertions.assertThatThrownBy { wrongPayloadType.serialize() } - .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive collection type for declaredType") + .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive collection type for declared type") } @CordaSerializable @@ -107,7 +107,9 @@ class ListsSerializationTest { val container = CovariantContainer(payload) fun verifyEnvelopeBody(envelope: Envelope) { - envelope.schema.types.single { typeNotation -> typeNotation.name == java.util.List::class.java.name + "" } + envelope.schema.types.single { typeNotation -> + typeNotation.name == "java.util.List<${Parent::class.java.name}>" + } } assertEqualAfterRoundTripSerialization(container, { bytes -> verifyEnvelope(bytes, ::verifyEnvelopeBody) }) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/AbstractAMQPSerializationSchemeTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/AbstractAMQPSerializationSchemeTest.kt index 8c38cbeb26..eaaeffb043 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/AbstractAMQPSerializationSchemeTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/AbstractAMQPSerializationSchemeTest.kt @@ -37,7 +37,7 @@ class AbstractAMQPSerializationSchemeTest { null) - val factory = TestSerializerFactory(TESTING_CONTEXT.whitelist, TESTING_CONTEXT.deserializationClassLoader) + val factory = SerializerFactoryBuilder.build(TESTING_CONTEXT.whitelist, TESTING_CONTEXT.deserializationClassLoader) val maxFactories = 512 val backingMap = AccessOrderLinkedHashMap, SerializerFactory>({ maxFactories }) val scheme = object : AbstractAMQPSerializationScheme(emptySet(), backingMap, createSerializerFactoryFactory()) { @@ -55,7 +55,6 @@ class AbstractAMQPSerializationSchemeTest { } - IntStream.range(0, 2048).parallel().forEach { val context = if (ThreadLocalRandom.current().nextBoolean()) { genesisContext.withClassLoader(URLClassLoader(emptyArray())) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/CorDappSerializerTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/CorDappSerializerTests.kt index d38187f6b7..d8c263d113 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/CorDappSerializerTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/CorDappSerializerTests.kt @@ -13,16 +13,13 @@ import kotlin.test.assertEquals class CorDappSerializerTests { data class NeedsProxy(val a: String) - private fun proxyFactory(serializers: List>): SerializerFactory { - val factory = SerializerFactoryBuilder.build(AllWhitelist, - ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()), - DefaultEvolutionSerializerProvider) - + private fun proxyFactory( + serializers: List> + ) = SerializerFactoryBuilder.build(AllWhitelist, + ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader())).apply { serializers.forEach { - factory.registerExternal(CorDappCustomSerializer(it, factory)) + registerExternal(CorDappCustomSerializer(it, this)) } - - return factory } class NeedsProxyProxySerializer : SerializationCustomSerializer { diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryOfEnumsTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryOfEnumsTest.kt index 5520d3ea6c..7fc19130e5 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryOfEnumsTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryOfEnumsTest.kt @@ -20,15 +20,15 @@ class DeserializeNeedingCarpentryOfEnumsTest : AmqpCarpenterBase(AllWhitelist) { // Setup the test // val setupFactory = testDefaultFactoryNoEvolution() - + val classCarpenter = ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()) val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF", "GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() }) // create the enum - val testEnumType = setupFactory.classCarpenter.build(EnumSchema("test.testEnumType", enumConstants)) + val testEnumType = classCarpenter.build(EnumSchema("test.testEnumType", enumConstants)) // create the class that has that enum as an element - val testClassType = setupFactory.classCarpenter.build(ClassSchema("test.testClassType", + val testClassType = classCarpenter.build(ClassSchema("test.testClassType", mapOf("a" to NonNullableField(testEnumType)))) // create an instance of the class we can then serialise @@ -59,16 +59,16 @@ class DeserializeNeedingCarpentryOfEnumsTest : AmqpCarpenterBase(AllWhitelist) { // Setup the test // val setupFactory = testDefaultFactoryNoEvolution() - + val classCarpenter = ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()) val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF", "GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() }) // create the enum - val testEnumType1 = setupFactory.classCarpenter.build(EnumSchema("test.testEnumType1", enumConstants)) - val testEnumType2 = setupFactory.classCarpenter.build(EnumSchema("test.testEnumType2", enumConstants)) + val testEnumType1 = classCarpenter.build(EnumSchema("test.testEnumType1", enumConstants)) + val testEnumType2 = classCarpenter.build(EnumSchema("test.testEnumType2", enumConstants)) // create the class that has that enum as an element - val testClassType = setupFactory.classCarpenter.build(ClassSchema("test.testClassType", + val testClassType = classCarpenter.build(ClassSchema("test.testClassType", mapOf( "a" to NonNullableField(testEnumType1), "b" to NonNullableField(testEnumType2), diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt index d8e3fc57df..5b00c4bf30 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt @@ -441,7 +441,4 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase(AllWhitelis assertEquals(0b1010.toByte(), deserializedObj::class.java.getMethod("getByteB").invoke(deserializedObj)) assertEquals(null, deserializedObj::class.java.getMethod("getByteC").invoke(deserializedObj)) } -} - - - +} \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt index 6dc46436b9..e419963a2b 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt @@ -74,7 +74,7 @@ class DeserializeSimpleTypesTests { val ia = IA(arrayOf(1, 2, 3)) assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString()) - assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]") + assertEquals(AMQPTypeIdentifiers.nameForType(ia.ia::class.java), "int[]") val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) @@ -93,7 +93,7 @@ class DeserializeSimpleTypesTests { val ia = IA(arrayOf(Integer(1), Integer(2), Integer(3))) assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString()) - assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]") + assertEquals(AMQPTypeIdentifiers.nameForType(ia.ia::class.java), "int[]") val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) @@ -116,7 +116,7 @@ class DeserializeSimpleTypesTests { val ia = IA(v) assertEquals("class [I", ia.ia::class.java.toString()) - assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(ia.ia::class.java), "int[p]") val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) @@ -134,7 +134,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf('a', 'b', 'c')) assertEquals("class [Ljava.lang.Character;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "char[]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -154,7 +154,7 @@ class DeserializeSimpleTypesTests { val c = C(v) assertEquals("class [C", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "char[p]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) var deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -183,7 +183,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf(true, false, false, true)) assertEquals("class [Ljava.lang.Boolean;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "boolean[]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -203,7 +203,7 @@ class DeserializeSimpleTypesTests { c.c[0] = true; c.c[1] = false; c.c[2] = false; c.c[3] = true assertEquals("class [Z", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "boolean[p]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -222,7 +222,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf(0b0001, 0b0101, 0b1111)) assertEquals("class [Ljava.lang.Byte;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "byte[]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "byte[]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -241,7 +241,7 @@ class DeserializeSimpleTypesTests { c.c[0] = 0b0001; c.c[1] = 0b0101; c.c[2] = 0b1111 assertEquals("class [B", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "binary") + assertEquals("binary", AMQPTypeIdentifiers.nameForType(c.c::class.java)) val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -267,7 +267,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf(1, 2, 3)) assertEquals("class [Ljava.lang.Short;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "short[]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -286,7 +286,7 @@ class DeserializeSimpleTypesTests { c.c[0] = 1; c.c[1] = 2; c.c[2] = 5 assertEquals("class [S", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "short[p]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -304,7 +304,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf(2147483650, -2147483800, 10)) assertEquals("class [Ljava.lang.Long;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "long[]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -323,7 +323,7 @@ class DeserializeSimpleTypesTests { c.c[0] = 2147483650; c.c[1] = -2147483800; c.c[2] = 10 assertEquals("class [J", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "long[p]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -341,7 +341,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf(10F, 100.023232F, -1455.433400F)) assertEquals("class [Ljava.lang.Float;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[]") + assertEquals("float[]", AMQPTypeIdentifiers.nameForType(c.c::class.java)) val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -360,7 +360,7 @@ class DeserializeSimpleTypesTests { c.c[0] = 10F; c.c[1] = 100.023232F; c.c[2] = -1455.433400F assertEquals("class [F", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "float[p]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -378,7 +378,7 @@ class DeserializeSimpleTypesTests { val c = C(arrayOf(10.0, 100.2, -1455.2)) assertEquals("class [Ljava.lang.Double;", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "double[]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) @@ -397,7 +397,7 @@ class DeserializeSimpleTypesTests { c.c[0] = 10.0; c.c[1] = 100.2; c.c[2] = -1455.2 assertEquals("class [D", c.c::class.java.toString()) - assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[p]") + assertEquals(AMQPTypeIdentifiers.nameForType(c.c::class.java), "double[p]") val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializedParameterizedTypeTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializedParameterizedTypeTests.kt deleted file mode 100644 index e2e718cc8d..0000000000 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializedParameterizedTypeTests.kt +++ /dev/null @@ -1,105 +0,0 @@ -package net.corda.serialization.internal.amqp - -import org.junit.Test -import java.io.NotSerializableException -import kotlin.test.assertEquals - -class DeserializedParameterizedTypeTests { - private fun normalise(string: String): String { - return string.replace(" ", "") - } - - private fun verify(typeName: String) { - val type = DeserializedParameterizedType.make(typeName) - assertEquals(normalise(type.typeName), normalise(typeName)) - } - - @Test - fun `test nested`() { - verify(" java.util.Map < java.util.Map< java.lang.String, java.lang.Integer >, java.util.Map < java.lang.Long , java.lang.String > >") - } - - @Test - fun `test simple`() { - verify("java.util.List") - } - - @Test - fun `test multiple args`() { - verify("java.util.Map") - } - - @Test - fun `test trailing whitespace`() { - verify("java.util.Map ") - } - - @Test - fun `test list of commands`() { - verify("java.util.List>") - } - - @Test(expected = NotSerializableException::class) - fun `test trailing text`() { - verify("java.util.Mapfoo") - } - - @Test(expected = NotSerializableException::class) - fun `test trailing comma`() { - verify("java.util.Map") - } - - @Test(expected = NotSerializableException::class) - fun `test leading comma`() { - verify("java.util.Map<,java.lang.String, java.lang.Integer>") - } - - @Test(expected = NotSerializableException::class) - fun `test middle comma`() { - verify("java.util.Map<,java.lang.String,, java.lang.Integer>") - } - - @Test(expected = NotSerializableException::class) - fun `test trailing close`() { - verify("java.util.Map>") - } - - @Test(expected = NotSerializableException::class) - fun `test empty params`() { - verify("java.util.Map<>") - } - - @Test(expected = NotSerializableException::class) - fun `test mid whitespace`() { - verify("java.u til.List") - } - - @Test(expected = NotSerializableException::class) - fun `test mid whitespace2`() { - verify("java.util.List") - } - - @Test(expected = NotSerializableException::class) - fun `test wrong number of parameters`() { - verify("java.util.List") - } - - @Test - fun `test no parameters`() { - verify("java.lang.String") - } - - @Test(expected = NotSerializableException::class) - fun `test parameters on non-generic type`() { - verify("java.lang.String") - } - - @Test(expected = NotSerializableException::class) - fun `test excessive nesting`() { - var nested = "java.lang.Integer" - for (i in 1..DeserializedParameterizedType.MAX_DEPTH) { - nested = "java.util.List<$nested>" - } - verify(nested) - } -} \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumEvolvabilityTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumEvolvabilityTests.kt index 03978266b6..ef0c6d38c4 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumEvolvabilityTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EnumEvolvabilityTests.kt @@ -392,27 +392,10 @@ class EnumEvolvabilityTests { data class C1(val annotatedEnum: AnnotatedEnumOnce) val sf = testDefaultFactory() - val f = sf.javaClass.getDeclaredField("transformsCache") - f.isAccessible = true - - @Suppress("UNCHECKED_CAST") - val transformsCache = f.get(sf) as ConcurrentHashMap>> - - assertEquals(0, transformsCache.size) val sb1 = TestSerializationOutput(VERBOSE, sf).serializeAndReturnSchema(C1(AnnotatedEnumOnce.D)) - - assertEquals(2, transformsCache.size) - assertTrue(transformsCache.containsKey(C1::class.java.name)) - assertTrue(transformsCache.containsKey(AnnotatedEnumOnce::class.java.name)) - val sb2 = TestSerializationOutput(VERBOSE, sf).serializeAndReturnSchema(C2(AnnotatedEnumOnce.D)) - assertEquals(3, transformsCache.size) - assertTrue(transformsCache.containsKey(C1::class.java.name)) - assertTrue(transformsCache.containsKey(C2::class.java.name)) - assertTrue(transformsCache.containsKey(AnnotatedEnumOnce::class.java.name)) - assertEquals(sb1.transformsSchema.types[AnnotatedEnumOnce::class.java.name], sb2.transformsSchema.types[AnnotatedEnumOnce::class.java.name]) } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactoryTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactoryTests.kt new file mode 100644 index 0000000000..0dda90607b --- /dev/null +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerFactoryTests.kt @@ -0,0 +1,55 @@ +package net.corda.serialization.internal.amqp + +import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope +import net.corda.serialization.internal.amqp.testutils.serialize +import net.corda.serialization.internal.amqp.testutils.testDefaultFactory +import net.corda.serialization.internal.model.RemoteTypeInformation +import net.corda.serialization.internal.model.TypeIdentifier +import org.junit.Test +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertNull + +class EvolutionSerializerFactoryTests { + + private val factory = testDefaultFactory() + + @Test + fun preservesDataWhenFlagSet() { + val nonStrictEvolutionSerializerFactory = DefaultEvolutionSerializerFactory( + factory, + ClassLoader.getSystemClassLoader(), + false) + + val strictEvolutionSerializerFactory = DefaultEvolutionSerializerFactory( + factory, + ClassLoader.getSystemClassLoader(), + true) + + @Suppress("unused") + class C(val importantFieldA: Int) + val (_, env) = DeserializationInput(factory).deserializeAndReturnEnvelope( + SerializationOutput(factory).serialize(C(1))) + + val remoteTypeInformation = AMQPRemoteTypeModel().interpret(SerializationSchemas(env.schema, env.transformsSchema)) + .values.find { it.typeIdentifier == TypeIdentifier.forClass(C::class.java) } + as RemoteTypeInformation.Composable + + val withAddedField = remoteTypeInformation.copy(properties = remoteTypeInformation.properties.plus( + "importantFieldB" to remoteTypeInformation.properties["importantFieldA"]!!)) + + val localTypeInformation = factory.getTypeInformation(C::class.java) + + // No evolution required with original fields. + assertNull(strictEvolutionSerializerFactory.getEvolutionSerializer(remoteTypeInformation, localTypeInformation)) + + // Returns an evolution serializer if the fields have changed. + assertNotNull(nonStrictEvolutionSerializerFactory.getEvolutionSerializer(withAddedField, localTypeInformation)) + + // Fails in strict mode if the remote type information includes a field not included in the local type. + assertFailsWith { + strictEvolutionSerializerFactory.getEvolutionSerializer(withAddedField, localTypeInformation) + } + } + +} \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerProviderTesting.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerProviderTesting.kt deleted file mode 100644 index 7239e6a467..0000000000 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolutionSerializerProviderTesting.kt +++ /dev/null @@ -1,22 +0,0 @@ -package net.corda.serialization.internal.amqp - -import java.io.NotSerializableException - -/** - * An implementation of [EvolutionSerializerProvider] that disables all evolution within a - * [SerializerFactory]. This is most useful in testing where it is known that evolution should not be - * occurring and where bugs may be hidden by transparent invocation of an [EvolutionSerializer]. This - * prevents that by simply throwing an exception whenever such a serializer is requested. - */ -object FailIfEvolutionAttempted : EvolutionSerializerProvider { - override fun getEvolutionSerializer(factory: SerializerFactory, - typeNotation: TypeNotation, - newSerializer: AMQPSerializer, - schemas: SerializationSchemas): AMQPSerializer { - throw NotSerializableException("No evolution should be occurring\n" + - " ${typeNotation.name}\n" + - " ${typeNotation.descriptor.name}\n" + - " ${newSerializer.type.typeName}\n" + - " ${newSerializer.typeDescriptor}\n\n${schemas.schema}") - } -} diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolvabilityTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolvabilityTests.kt index e0109515d6..ac35801b27 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolvabilityTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/EvolvabilityTests.kt @@ -461,6 +461,17 @@ class EvolvabilityTests { assertEquals(oa, outer.a) assertEquals(ia, outer.b.a) assertEquals(null, outer.b.b) + + // Repeat, but receiving a message with the newer version of Inner + val newVersion = SerializationOutput(sf).serializeAndReturnSchema(Outer(oa, Inner(ia, "new value"))) + val model = AMQPRemoteTypeModel() + val remoteTypeInfo = model.interpret(SerializationSchemas(newVersion.schema, newVersion.transformsSchema)) + println(remoteTypeInfo) + + val newOuter = DeserializationInput(sf).deserialize(SerializedBytes(newVersion.obj.bytes)) + assertEquals(oa, newOuter.a) + assertEquals(ia, newOuter.b.a) + assertEquals("new value", newOuter.b.b) } @Test diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/FingerPrinterTesting.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/FingerPrinterTesting.kt index 1e5f7c3d82..cad38afa53 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/FingerPrinterTesting.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/FingerPrinterTesting.kt @@ -1,23 +1,25 @@ package net.corda.serialization.internal.amqp import org.junit.Test -import java.lang.reflect.Type import kotlin.test.assertEquals import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.amqp.testutils.TestSerializationOutput import net.corda.serialization.internal.amqp.testutils.serializeAndReturnSchema import net.corda.serialization.internal.carpenter.ClassCarpenterImpl +import net.corda.serialization.internal.model.ConfigurableLocalTypeModel +import net.corda.serialization.internal.model.LocalTypeInformation +import net.corda.serialization.internal.model.FingerPrinter class FingerPrinterTesting : FingerPrinter { private var index = 0 - private val cache = mutableMapOf() + private val cache = mutableMapOf() - override fun fingerprint(type: Type): String { - return cache.computeIfAbsent(type) { index++.toString() } + override fun fingerprint(typeInformation: LocalTypeInformation): String { + return cache.computeIfAbsent(typeInformation) { index++.toString() } } @Suppress("UNUSED") - fun changeFingerprint(type: Type) { + fun changeFingerprint(type: LocalTypeInformation) { cache.computeIfAbsent(type) { "" }.apply { index++.toString() } } } @@ -30,10 +32,14 @@ class FingerPrinterTestingTests { @Test fun testingTest() { val fpt = FingerPrinterTesting() - assertEquals("0", fpt.fingerprint(Integer::class.java)) - assertEquals("1", fpt.fingerprint(String::class.java)) - assertEquals("0", fpt.fingerprint(Integer::class.java)) - assertEquals("1", fpt.fingerprint(String::class.java)) + val descriptorBasedSerializerRegistry = DefaultDescriptorBasedSerializerRegistry() + val customSerializerRegistry: CustomSerializerRegistry = CachingCustomSerializerRegistry(descriptorBasedSerializerRegistry) + val typeModel = ConfigurableLocalTypeModel(WhitelistBasedTypeModelConfiguration(AllWhitelist, customSerializerRegistry)) + + assertEquals("0", fpt.fingerprint(typeModel.inspect(Integer::class.java))) + assertEquals("1", fpt.fingerprint(typeModel.inspect(String::class.java))) + assertEquals("0", fpt.fingerprint(typeModel.inspect(Integer::class.java))) + assertEquals("1", fpt.fingerprint(typeModel.inspect(String::class.java))) } @Test @@ -42,7 +48,7 @@ class FingerPrinterTestingTests { val factory = SerializerFactoryBuilder.build(AllWhitelist, ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()), - fingerPrinterProvider = { _ -> FingerPrinterTesting() }) + overrideFingerPrinter = FingerPrinterTesting()) val blob = TestSerializationOutput(VERBOSE, factory).serializeAndReturnSchema(C(1, 2L)) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/GenericsTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/GenericsTests.kt index 544d467bd8..c85fc2a343 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/GenericsTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/GenericsTests.kt @@ -40,15 +40,6 @@ class GenericsTests { private fun BytesAndSchemas.printSchema() = if (VERBOSE) println("${this.schema}\n") else Unit - private fun MutableMap>.printKeyToType() { - if (!VERBOSE) return - - forEach { - println("Key = ${it.key} - ${it.value.type.typeName}") - } - println() - } - @Test fun twoDifferentTypesSameParameterizedOuter() { data class G(val a: A) @@ -57,12 +48,8 @@ class GenericsTests { val bytes1 = SerializationOutput(factory).serializeAndReturnSchema(G("hi")).apply { printSchema() } - factory.serializersByDescriptor.printKeyToType() - val bytes2 = SerializationOutput(factory).serializeAndReturnSchema(G(121)).apply { printSchema() } - factory.serializersByDescriptor.printKeyToType() - listOf(factory, testDefaultFactory()).forEach { f -> DeserializationInput(f).deserialize(bytes1.obj).apply { assertEquals("hi", this.a) } DeserializationInput(f).deserialize(bytes2.obj).apply { assertEquals(121, this.a) } @@ -94,15 +81,11 @@ class GenericsTests { val bytes = ser.serializeAndReturnSchema(G("hi")).apply { printSchema() } - factory.serializersByDescriptor.printKeyToType() - assertEquals("hi", DeserializationInput(factory).deserialize(bytes.obj).a) assertEquals("hi", DeserializationInput(altContextFactory).deserialize(bytes.obj).a) val bytes2 = ser.serializeAndReturnSchema(Wrapper(1, G("hi"))).apply { printSchema() } - factory.serializersByDescriptor.printKeyToType() - printSeparator() DeserializationInput(factory).deserialize(bytes2.obj).apply { @@ -161,21 +144,18 @@ class GenericsTests { ser.serialize(Wrapper(Container(InnerA(1)))).apply { factories.forEach { DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_a) } - it.serializersByDescriptor.printKeyToType(); printSeparator() } } ser.serialize(Wrapper(Container(InnerB(1)))).apply { factories.forEach { DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_b) } - it.serializersByDescriptor.printKeyToType(); printSeparator() } } ser.serialize(Wrapper(Container(InnerC("Ho ho ho")))).apply { factories.forEach { DeserializationInput(it).deserialize(this).apply { assertEquals("Ho ho ho", c.b.a_c) } - it.serializersByDescriptor.printKeyToType(); printSeparator() } } } @@ -217,7 +197,6 @@ class GenericsTests { ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()) )): SerializedBytes<*> { val bytes = SerializationOutput(factory).serializeAndReturnSchema(a) - factory.serializersByDescriptor.printKeyToType() bytes.printSchema() return bytes.obj } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/PrivatePropertyTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/PrivatePropertyTests.kt index b4012b5c00..43c248c360 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/PrivatePropertyTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/PrivatePropertyTests.kt @@ -3,19 +3,20 @@ package net.corda.serialization.internal.amqp import junit.framework.TestCase.assertTrue import junit.framework.TestCase.assertEquals import net.corda.core.serialization.ConstructorForDeserialization -import net.corda.serialization.internal.amqp.testutils.deserialize -import net.corda.serialization.internal.amqp.testutils.serializeAndReturnSchema -import net.corda.serialization.internal.amqp.testutils.serialize -import net.corda.serialization.internal.amqp.testutils.testDefaultFactoryNoEvolution +import net.corda.serialization.internal.amqp.testutils.* +import net.corda.serialization.internal.model.ConfigurableLocalTypeModel +import net.corda.serialization.internal.model.LocalPropertyInformation +import net.corda.serialization.internal.model.LocalTypeInformation import org.junit.Test -import org.apache.qpid.proton.amqp.Symbol import org.assertj.core.api.Assertions import java.io.NotSerializableException -import java.util.concurrent.ConcurrentHashMap import java.util.* class PrivatePropertyTests { - private val factory = testDefaultFactoryNoEvolution() + + private val registry = TestDescriptorBasedSerializerRegistry() + private val factory = testDefaultFactoryNoEvolution(registry) + val typeModel = ConfigurableLocalTypeModel(WhitelistBasedTypeModelConfiguration(factory.whitelist, factory)) @Test fun testWithOnePrivateProperty() { @@ -125,21 +126,13 @@ class PrivatePropertyTests { val schemaAndBlob = SerializationOutput(factory).serializeAndReturnSchema(c1) assertEquals(1, schemaAndBlob.schema.types.size) - val serializersByDescriptor = factory.serializersByDescriptor + val typeInformation = typeModel.inspect(C::class.java) + assertTrue(typeInformation is LocalTypeInformation.Composable) + typeInformation as LocalTypeInformation.Composable - val schemaDescriptor = schemaAndBlob.schema.types.first().descriptor.name - serializersByDescriptor.filterKeys { (it as Symbol) == schemaDescriptor }.values.apply { - assertEquals(1, this.size) - assertTrue(this.first() is ObjectSerializer) - val propertySerializers = (this.first() as ObjectSerializer).propertySerializers.serializationOrder.map { it.serializer } - assertEquals(2, propertySerializers.size) - // a was public so should have a synthesised getter - assertTrue(propertySerializers[0].propertyReader is PublicPropertyReader) - - // b is private and thus won't have teh getter so we'll have reverted - // to using reflection to remove the inaccessible property - assertTrue(propertySerializers[1].propertyReader is PrivatePropertyReader) - } + assertEquals(2, typeInformation.properties.size) + assertTrue(typeInformation.properties["a"] is LocalPropertyInformation.ConstructorPairedProperty) + assertTrue(typeInformation.properties["b"] is LocalPropertyInformation.PrivateConstructorPairedProperty) } @Test @@ -153,22 +146,14 @@ class PrivatePropertyTests { val schemaAndBlob = SerializationOutput(factory).serializeAndReturnSchema(c1) assertEquals(1, schemaAndBlob.schema.types.size) - val serializersByDescriptor = factory.serializersByDescriptor - val schemaDescriptor = schemaAndBlob.schema.types.first().descriptor.name - serializersByDescriptor.filterKeys { (it as Symbol) == schemaDescriptor }.values.apply { - assertEquals(1, this.size) - assertTrue(this.first() is ObjectSerializer) - val propertySerializers = (this.first() as ObjectSerializer).propertySerializers.serializationOrder.map { it.serializer } - assertEquals(2, propertySerializers.size) + val typeInformation = typeModel.inspect(C::class.java) + assertTrue(typeInformation is LocalTypeInformation.Composable) + typeInformation as LocalTypeInformation.Composable - // as before, a is public so we'll use the getter method - assertTrue(propertySerializers[0].propertyReader is PublicPropertyReader) - - // the getB() getter explicitly added means we should use the "normal" public - // method reader rather than the private oen - assertTrue(propertySerializers[1].propertyReader is PublicPropertyReader) - } + assertEquals(2, typeInformation.properties.size) + assertTrue(typeInformation.properties["a"] is LocalPropertyInformation.ConstructorPairedProperty) + assertTrue(typeInformation.properties["b"] is LocalPropertyInformation.ConstructorPairedProperty) } @Suppress("UNCHECKED_CAST") @@ -179,9 +164,8 @@ class PrivatePropertyTests { val c1 = Outer(Inner(1010101)) val output = SerializationOutput(factory).serializeAndReturnSchema(c1) - println (output.schema) - val serializersByDescriptor = factory.serializersByDescriptor + val serializersByDescriptor = registry.contents // Inner and Outer assertEquals(2, serializersByDescriptor.size) @@ -198,24 +182,13 @@ class PrivatePropertyTests { @Test fun allCapsProprtyNotPrivate() { data class C (val CCC: String) + val typeInformation = typeModel.inspect(C::class.java) - val output = SerializationOutput(factory).serializeAndReturnSchema(C("this is nice")) + assertTrue(typeInformation is LocalTypeInformation.Composable) + typeInformation as LocalTypeInformation.Composable - val serializersByDescriptor = factory.serializersByDescriptor - - val schemaDescriptor = output.schema.types.first().descriptor.name - serializersByDescriptor.filterKeys { (it as Symbol) == schemaDescriptor }.values.apply { - assertEquals(1, size) - - assertTrue(this.first() is ObjectSerializer) - val propertySerializers = (this.first() as ObjectSerializer).propertySerializers.serializationOrder.map { it.serializer } - - // CCC is the only property to be serialised - assertEquals(1, propertySerializers.size) - - // and despite being all caps it should still be a public getter - assertTrue(propertySerializers[0].propertyReader is PublicPropertyReader) - } + assertEquals(1, typeInformation.properties.size) + assertTrue(typeInformation.properties["CCC"] is LocalPropertyInformation.ConstructorPairedProperty) } } \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt index 0ee30b0373..3adba6ac32 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationOutputTests.kt @@ -21,7 +21,6 @@ import net.corda.core.utilities.OpaqueBytes import net.corda.node.serialization.amqp.AMQPServerSerializationScheme import net.corda.nodeapi.internal.crypto.ContentSignerBuilder import net.corda.serialization.internal.* -import net.corda.serialization.internal.amqp.SerializerFactory.Companion.isPrimitive import net.corda.serialization.internal.amqp.testutils.* import net.corda.serialization.internal.carpenter.ClassCarpenterImpl import net.corda.testing.contracts.DummyContract @@ -210,7 +209,7 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi private fun defaultFactory(): SerializerFactory { return SerializerFactoryBuilder.build(AllWhitelist, ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()), - evolutionSerializerProvider = FailIfEvolutionAttempted + allowEvolution = false ) } @@ -258,27 +257,27 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi @Test fun isPrimitive() { - assertTrue(isPrimitive(Character::class.java)) - assertTrue(isPrimitive(Boolean::class.java)) - assertTrue(isPrimitive(Byte::class.java)) - assertTrue(isPrimitive(UnsignedByte::class.java)) - assertTrue(isPrimitive(Short::class.java)) - assertTrue(isPrimitive(UnsignedShort::class.java)) - assertTrue(isPrimitive(Int::class.java)) - assertTrue(isPrimitive(UnsignedInteger::class.java)) - assertTrue(isPrimitive(Long::class.java)) - assertTrue(isPrimitive(UnsignedLong::class.java)) - assertTrue(isPrimitive(Float::class.java)) - assertTrue(isPrimitive(Double::class.java)) - assertTrue(isPrimitive(Decimal32::class.java)) - assertTrue(isPrimitive(Decimal64::class.java)) - assertTrue(isPrimitive(Decimal128::class.java)) - assertTrue(isPrimitive(Char::class.java)) - assertTrue(isPrimitive(Date::class.java)) - assertTrue(isPrimitive(UUID::class.java)) - assertTrue(isPrimitive(ByteArray::class.java)) - assertTrue(isPrimitive(String::class.java)) - assertTrue(isPrimitive(Symbol::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Character::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Boolean::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Byte::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(UnsignedByte::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Short::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(UnsignedShort::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Int::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(UnsignedInteger::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Long::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(UnsignedLong::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Float::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Double::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Decimal32::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Decimal64::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Decimal128::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Char::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Date::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(UUID::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(ByteArray::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(String::class.java)) + assertTrue(AMQPTypeIdentifiers.isPrimitive(Symbol::class.java)) } @Test @@ -475,10 +474,11 @@ class SerializationOutputTests(private val compression: CordaSerializationEncodi @Test fun `class constructor is invoked on deserialisation`() { compression == null || return // Manipulation of serialized bytes is invalid if they're compressed. - val ser = SerializationOutput(SerializerFactoryBuilder.build(AllWhitelist, + val serializerFactory = SerializerFactoryBuilder.build(AllWhitelist, ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()) - )) - val des = DeserializationInput(ser.serializerFactory) + ) + val ser = SerializationOutput(serializerFactory) + val des = DeserializationInput(serializerFactory) val serialisedOne = ser.serialize(NonZeroByte(1), compression).bytes val serialisedTwo = ser.serialize(NonZeroByte(2), compression).bytes diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationPropertyOrdering.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationPropertyOrdering.kt index 750a2f2c5c..5ffcfe2c90 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationPropertyOrdering.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationPropertyOrdering.kt @@ -1,12 +1,8 @@ package net.corda.serialization.internal.amqp import net.corda.core.serialization.ConstructorForDeserialization -import net.corda.serialization.internal.amqp.testutils.TestSerializationOutput -import net.corda.serialization.internal.amqp.testutils.deserialize -import net.corda.serialization.internal.amqp.testutils.serializeAndReturnSchema -import net.corda.serialization.internal.amqp.testutils.testDefaultFactoryNoEvolution +import net.corda.serialization.internal.amqp.testutils.* import org.junit.Test -import java.util.concurrent.ConcurrentHashMap import kotlin.test.assertEquals import org.apache.qpid.proton.amqp.Symbol import java.lang.reflect.Method @@ -17,7 +13,8 @@ class SerializationPropertyOrdering { companion object { val VERBOSE get() = false - val sf = testDefaultFactoryNoEvolution() + val registry = TestDescriptorBasedSerializerRegistry() + val sf = testDefaultFactoryNoEvolution(registry) } // Force object references to be ued to ensure we go through that code path @@ -100,25 +97,6 @@ class SerializationPropertyOrdering { assertEquals("e", this.fields[4].name) } - // Test needs to look at a bunch of private variables, change the access semantics for them - val fields : Map = mapOf ( - "setter" to PropertyAccessorGetterSetter::class.java.getDeclaredField("setter")).apply { - this.values.forEach { - it.isAccessible = true - } - } - - val serializersByDescriptor = sf.serializersByDescriptor - val schemaDescriptor = output.schema.types.first().descriptor.name - - // make sure that each property accessor has a setter to ensure we're using getter / setter instantiation - serializersByDescriptor.filterKeys { (it as Symbol) == schemaDescriptor }.values.apply { - assertEquals(1, this.size) - assertTrue(this.first() is ObjectSerializer) - val propertyAccessors = (this.first() as ObjectSerializer).propertySerializers.serializationOrder as List - propertyAccessors.forEach { property -> assertNotNull(fields["setter"]!!.get(property) as Method?) } - } - val input = DeserializationInput(sf).deserialize(output.obj) assertEquals(100, input.a) assertEquals(200, input.b) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationSchemaTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationSchemaTests.kt index 7a56cbb953..a504b25a8c 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationSchemaTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/SerializationSchemaTests.kt @@ -17,87 +17,4 @@ val TESTING_CONTEXT = SerializationContextImpl(amqpMagic, emptyMap(), true, SerializationContext.UseCase.Testing, - null) - -// Test factory that lets us count the number of serializer registration attempts -class TestSerializerFactory( - wl: ClassWhitelist, - cl: ClassLoader -) : DefaultSerializerFactory(wl, ClassCarpenterImpl(wl, cl, false), DefaultEvolutionSerializerProvider, ::SerializerFingerPrinter) { - var registerCount = 0 - - override fun register(customSerializer: CustomSerializer) { - ++registerCount - return super.register(customSerializer) - } -} - -// Instance of our test factory counting registration attempts. Sucks its global, but for testing purposes this -// is the easiest way of getting access to the object. -val testFactory = TestSerializerFactory(TESTING_CONTEXT.whitelist, TESTING_CONTEXT.deserializationClassLoader) - -// Serializer factory factory, plugs into the SerializationScheme and controls which factory type -// we make for each use case. For our tests we need to make sure if its the Testing use case we return -// the global factory object created above that counts registrations. -class TestSerializerFactoryFactory : SerializerFactoryFactoryImpl() { - override fun make(context: SerializationContext) = - when (context.useCase) { - SerializationContext.UseCase.Testing -> testFactory - else -> super.make(context) - } -} - -class AMQPTestSerializationScheme : AbstractAMQPSerializationScheme(emptySet(), AccessOrderLinkedHashMap { 128 }, TestSerializerFactoryFactory()) { - override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { - throw UnsupportedOperationException() - } - - override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { - throw UnsupportedOperationException() - } - - override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase) = true -} - -// Test SerializationFactory that wraps a serialization scheme that just allows us to call .serialize. -// Returns the testing scheme we created above that wraps the testing factory. -class TestSerializationFactory : SerializationFactory() { - private val scheme = AMQPTestSerializationScheme() - - override fun deserialize( - byteSequence: ByteSequence, - clazz: Class, context: - SerializationContext - ): T { - throw UnsupportedOperationException() - } - - override fun deserializeWithCompatibleContext( - byteSequence: ByteSequence, - clazz: Class, - context: SerializationContext - ): ObjectWithCompatibleContext { - throw UnsupportedOperationException() - } - - override fun serialize(obj: T, context: SerializationContext) = scheme.serialize(obj, context) -} - -// The actual test -class SerializationSchemaTests { - @Test - fun onlyRegisterCustomSerializersOnce() { - @CordaSerializable - data class C(val a: Int) - - val c = C(1) - val testSerializationFactory = TestSerializationFactory() - val expectedCustomSerializerCount = 41 - - assertEquals(0, testFactory.registerCount) - c.serialize(testSerializationFactory, TESTING_CONTEXT) - assertEquals(expectedCustomSerializerCount, testFactory.registerCount) - c.serialize(testSerializationFactory, TESTING_CONTEXT) - assertEquals(expectedCustomSerializerCount, testFactory.registerCount) - } -} \ No newline at end of file + null) \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/StaticInitialisationOfSerializedObjectTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/StaticInitialisationOfSerializedObjectTest.kt index f8450faf49..ab13ffa712 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/StaticInitialisationOfSerializedObjectTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/StaticInitialisationOfSerializedObjectTest.kt @@ -6,6 +6,7 @@ import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.amqp.testutils.deserialize import net.corda.serialization.internal.carpenter.ClassCarpenterImpl import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Ignore import org.junit.Test import java.io.NotSerializableException import java.lang.reflect.Type @@ -44,6 +45,7 @@ class StaticInitialisationOfSerializedObjectTest { C() } + @Ignore("Suppressing this, as it depends on obtaining internal access to serialiser cache") @Test fun kotlinObjectWithCompanionObject() { data class D(val c: C) @@ -63,7 +65,7 @@ class StaticInitialisationOfSerializedObjectTest { // build a serializer for type D without an instance of it to serialise, since // we can't actually construct one - sf.get(null, D::class.java) + sf.get(D::class.java) // post creation of the serializer we should have two elements in the map, this // proves we didn't statically construct an instance of C when building the serializer diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt index b248d11c2d..77dc08ab79 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/testutils/AMQPTestUtils.kt @@ -18,20 +18,45 @@ import java.io.File.separatorChar import java.io.NotSerializableException import java.nio.file.StandardCopyOption.REPLACE_EXISTING -fun testDefaultFactory() = SerializerFactoryBuilder.build(AllWhitelist, - ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()) -) +/** + * For tests that want to see inside the serializer registry + */ +class TestDescriptorBasedSerializerRegistry : DescriptorBasedSerializerRegistry { + val contents = mutableMapOf>() -fun testDefaultFactoryNoEvolution(): SerializerFactory { - return SerializerFactoryBuilder.build( - AllWhitelist, - ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()), - FailIfEvolutionAttempted) + override fun get(descriptor: String): AMQPSerializer? = contents[descriptor] + + override fun set(descriptor: String, serializer: AMQPSerializer) { + contents.putIfAbsent(descriptor, serializer) + } + + override fun getOrBuild(descriptor: String, builder: () -> AMQPSerializer): AMQPSerializer = + get(descriptor) ?: builder().also { set(descriptor, it) } } -fun testDefaultFactoryWithWhitelist() = SerializerFactoryBuilder.build(EmptyWhitelist, - ClassCarpenterImpl(EmptyWhitelist, ClassLoader.getSystemClassLoader()) -) +@JvmOverloads +fun testDefaultFactory(descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry = + DefaultDescriptorBasedSerializerRegistry()) = + SerializerFactoryBuilder.build( + AllWhitelist, + ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()), + descriptorBasedSerializerRegistry = descriptorBasedSerializerRegistry) + +@JvmOverloads +fun testDefaultFactoryNoEvolution(descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry = + DefaultDescriptorBasedSerializerRegistry()): SerializerFactory = + SerializerFactoryBuilder.build( + AllWhitelist, + ClassCarpenterImpl(AllWhitelist, ClassLoader.getSystemClassLoader()), + descriptorBasedSerializerRegistry = descriptorBasedSerializerRegistry, + allowEvolution = false) + +@JvmOverloads +fun testDefaultFactoryWithWhitelist(descriptorBasedSerializerRegistry: DescriptorBasedSerializerRegistry = + DefaultDescriptorBasedSerializerRegistry()) = + SerializerFactoryBuilder.build(EmptyWhitelist, + ClassCarpenterImpl(EmptyWhitelist, ClassLoader.getSystemClassLoader()), + descriptorBasedSerializerRegistry = descriptorBasedSerializerRegistry) class TestSerializationOutput( private val verbose: Boolean, diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CalculatedValuesToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CalculatedValuesToClassCarpenterTests.kt deleted file mode 100644 index 893f81833b..0000000000 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CalculatedValuesToClassCarpenterTests.kt +++ /dev/null @@ -1,101 +0,0 @@ -package net.corda.serialization.internal.carpenter - -import net.corda.core.serialization.SerializableCalculatedProperty -import net.corda.serialization.internal.AllWhitelist -import net.corda.serialization.internal.amqp.CompositeType -import net.corda.serialization.internal.amqp.DeserializationInput -import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope -import net.corda.serialization.internal.amqp.testutils.testDefaultFactoryNoEvolution -import org.junit.Test -import kotlin.test.assertEquals - -class CalculatedValuesToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { - - interface Parent { - @get:SerializableCalculatedProperty - val doubled: Int - } - - @Test - fun calculatedValues() { - data class C(val i: Int): Parent { - @get:SerializableCalculatedProperty - val squared = (i * i).toString() - - override val doubled get() = i * 2 - } - - val factory = testDefaultFactoryNoEvolution() - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(C(2))) - val amqpObj = obj.obj - val serSchema = obj.envelope.schema - - assertEquals(2, amqpObj.i) - assertEquals("4", amqpObj.squared) - assertEquals(2, serSchema.types.size) - require(serSchema.types[0] is CompositeType) - - val concrete = serSchema.types[0] as CompositeType - assertEquals(3, concrete.fields.size) - assertEquals("doubled", concrete.fields[0].name) - assertEquals("int", concrete.fields[0].type) - assertEquals("i", concrete.fields[1].name) - assertEquals("int", concrete.fields[1].type) - assertEquals("squared", concrete.fields[2].name) - assertEquals("string", concrete.fields[2].type) - - val l1 = serSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - assertEquals(0, l1.size) - val mangleSchema = serSchema.mangleNames(listOf((classTestName("C")))) - val l2 = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - val aName = mangleName(classTestName("C")) - - assertEquals(1, l2.size) - val aSchema = l2.carpenterSchemas.find { it.name == aName }!! - - val pinochio = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = pinochio.constructors[0].newInstance(4, 2, "4") - - assertEquals(pinochio.getMethod("getI").invoke(p), amqpObj.i) - assertEquals(pinochio.getMethod("getSquared").invoke(p), amqpObj.squared) - assertEquals(pinochio.getMethod("getDoubled").invoke(p), amqpObj.doubled) - - val upcast = p as Parent - assertEquals(upcast.doubled, amqpObj.doubled) - } - - @Test - fun implementingClassDoesNotCalculateValue() { - class C(override val doubled: Int): Parent - - val factory = testDefaultFactoryNoEvolution() - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(C(5))) - val amqpObj = obj.obj - val serSchema = obj.envelope.schema - - assertEquals(2, serSchema.types.size) - require(serSchema.types[0] is CompositeType) - - val concrete = serSchema.types[0] as CompositeType - assertEquals(1, concrete.fields.size) - assertEquals("doubled", concrete.fields[0].name) - assertEquals("int", concrete.fields[0].type) - - val l1 = serSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - assertEquals(0, l1.size) - val mangleSchema = serSchema.mangleNames(listOf((classTestName("C")))) - val l2 = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - val aName = mangleName(classTestName("C")) - - assertEquals(1, l2.size) - val aSchema = l2.carpenterSchemas.find { it.name == aName }!! - - val pinochio = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = pinochio.constructors[0].newInstance(5) - - assertEquals(pinochio.getMethod("getDoubled").invoke(p), amqpObj.doubled) - - val upcast = p as Parent - assertEquals(upcast.doubled, amqpObj.doubled) - } -} \ No newline at end of file diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTest.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTest.kt index ac510642d7..d9a7b19fd6 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTest.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTest.kt @@ -27,7 +27,7 @@ class ClassCarpenterTest { @Test fun empty() { - val clazz = cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null)) + val clazz = cc.build(ClassSchema("gen.EmptyClass", emptyMap())) assertEquals(0, clazz.nonSyntheticFields.size) assertEquals(2, clazz.nonSyntheticMethods.size) // get, toString assertEquals(0, clazz.declaredConstructors[0].parameterCount) @@ -97,8 +97,8 @@ class ClassCarpenterTest { @Test(expected = DuplicateNameException::class) fun duplicates() { - cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null)) - cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null)) + cc.build(ClassSchema("gen.EmptyClass", emptyMap())) + cc.build(ClassSchema("gen.EmptyClass", emptyMap())) } @Test diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt index 0068399ed6..0ec2fed1ab 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt @@ -1,41 +1,14 @@ package net.corda.serialization.internal.carpenter +import com.google.common.reflect.TypeToken import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.SerializedBytes import net.corda.serialization.internal.amqp.* -import net.corda.serialization.internal.amqp.Field -import net.corda.serialization.internal.amqp.Schema +import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope import net.corda.serialization.internal.amqp.testutils.serialize import net.corda.serialization.internal.amqp.testutils.testName - -fun mangleName(name: String) = "${name}__carpenter" - -/** - * given a list of class names work through the amqp envelope schema and alter any that - * match in the fashion defined above - */ -fun Schema.mangleNames(names: List): Schema { - val newTypes: MutableList = mutableListOf() - - for (type in types) { - val newName = if (type.name in names) mangleName(type.name) else type.name - val newProvides = type.provides.map { if (it in names) mangleName(it) else it } - val newFields = mutableListOf() - - (type as CompositeType).fields.forEach { - val fieldType = if (it.type in names) mangleName(it.type) else it.type - val requires = - if (it.requires.isNotEmpty() && (it.requires[0] in names)) listOf(mangleName(it.requires[0])) - else it.requires - - newFields.add(it.copy(type = fieldType, requires = requires)) - } - - newTypes.add(type.copy(name = newName, provides = newProvides, fields = newFields)) - } - - return Schema(types = newTypes) -} +import net.corda.serialization.internal.model.* +import org.junit.Assert.assertTrue /** * Custom implementation of a [SerializerFactory] where we need to give it a class carpenter @@ -48,7 +21,78 @@ open class AmqpCarpenterBase(whitelist: ClassWhitelist) { var cc = ClassCarpenterImpl(whitelist = whitelist) var factory = serializerFactoryExternalCarpenter(cc) - fun serialise(obj: T): SerializedBytes = SerializationOutput(factory).serialize(obj) - @Suppress("NOTHING_TO_INLINE") - inline fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz" + protected val remoteTypeModel = AMQPRemoteTypeModel() + protected val typeLoader = ClassCarpentingTypeLoader(SchemaBuildingRemoteTypeCarpenter(cc), cc.classloader) + + protected inline fun T.roundTrip(): ObjectAndEnvelope = + DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(this)) + + protected val Envelope.typeInformation: Map get() = + remoteTypeModel.interpret(SerializationSchemas(schema, transformsSchema)) + + protected inline fun Envelope.typeInformationFor(): RemoteTypeInformation { + val interpreted = typeInformation + val type = object : TypeToken() {}.type + return interpreted.values.find { it.typeIdentifier == TypeIdentifier.forGenericType(type) } + as RemoteTypeInformation + } + + protected inline fun Envelope.getMangled(): RemoteTypeInformation = + typeInformationFor().mangle() + + protected fun serialise(obj: T): SerializedBytes = SerializationOutput(factory).serialize(obj) + + protected inline fun RemoteTypeInformation.mangle(): RemoteTypeInformation { + val from = TypeIdentifier.forGenericType(object : TypeToken() {}.type) + return rename(from, from.mangle()) + } + + protected fun TypeIdentifier.mangle(): TypeIdentifier = when(this) { + is TypeIdentifier.Unparameterised -> copy(name = name + "_carpenter") + is TypeIdentifier.Parameterised -> copy(name = name + "_carpenter") + is TypeIdentifier.Erased -> copy(name = name + "_carpenter") + is TypeIdentifier.ArrayOf -> copy(componentType = componentType.mangle()) + else -> this + } + + protected fun TypeIdentifier.rename(from: TypeIdentifier, to: TypeIdentifier): TypeIdentifier = when(this) { + from -> to.rename(from, to) + is TypeIdentifier.Parameterised -> copy(parameters = parameters.map { it.rename(from, to) }) + is TypeIdentifier.ArrayOf -> copy(componentType = componentType.rename(from, to)) + else -> this + } + + protected fun RemoteTypeInformation.rename(from: TypeIdentifier, to: TypeIdentifier): RemoteTypeInformation = when(this) { + is RemoteTypeInformation.Composable -> copy( + typeIdentifier = typeIdentifier.rename(from, to), + properties = properties.mapValues { (_, property) -> property.copy(type = property.type.rename(from, to)) }, + interfaces = interfaces.map { it.rename(from, to) }, + typeParameters = typeParameters.map { it.rename(from, to) }) + is RemoteTypeInformation.Unparameterised -> copy(typeIdentifier = typeIdentifier.rename(from, to)) + is RemoteTypeInformation.Parameterised -> copy( + typeIdentifier = typeIdentifier.rename(from, to), + typeParameters = typeParameters.map { it.rename(from, to) }) + is RemoteTypeInformation.AnInterface -> copy( + typeIdentifier = typeIdentifier.rename(from, to), + properties = properties.mapValues { (_, property) -> property.copy(type = property.type.rename(from, to)) }, + interfaces = interfaces.map { it.rename(from, to) }, + typeParameters = typeParameters.map { it.rename(from, to) }) + is RemoteTypeInformation.AnArray -> copy(componentType = componentType.rename(from, to)) + is RemoteTypeInformation.AnEnum -> copy( + typeIdentifier = typeIdentifier.rename(from, to)) + else -> this + } + + protected fun RemoteTypeInformation.load(): Class<*> = + typeLoader.load(listOf(this))[typeIdentifier]!!.asClass() + + protected fun assertCanLoadAll(vararg types: RemoteTypeInformation) { + assertTrue(typeLoader.load(types.asList()).keys.containsAll(types.map { it.typeIdentifier })) + } + + protected fun Class<*>.new(vararg constructorParams: Any?) = + constructors[0].newInstance(*constructorParams)!! + + protected fun Any.get(propertyName: String): Any = + this::class.java.getMethod("get${propertyName.capitalize()}").invoke(this) } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt index 5e5b262a6f..b1700ae8c2 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt @@ -2,13 +2,11 @@ package net.corda.serialization.internal.carpenter import net.corda.core.serialization.CordaSerializable import net.corda.serialization.internal.AllWhitelist -import net.corda.serialization.internal.amqp.CompositeType -import net.corda.serialization.internal.amqp.DeserializationInput -import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope import org.junit.Test +import java.io.NotSerializableException +import java.util.* import kotlin.test.assertEquals -import kotlin.test.assertFalse -import kotlin.test.assertTrue +import kotlin.test.assertFailsWith @CordaSerializable interface I_ { @@ -16,258 +14,96 @@ interface I_ { } class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { - @Test - fun bothKnown() { - val testA = 10 - val testB = 20 + @Test + fun parentIsUnknown() { @CordaSerializable data class A(val a: Int) @CordaSerializable data class B(val a: A, var b: Int) - val b = B(A(testA), testB) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + val (_, envelope) = B(A(10), 20).roundTrip() - val amqpObj = obj.obj - - assertEquals(testB, amqpObj.b) - assertEquals(testA, amqpObj.a.a) - assertEquals(2, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - require(obj.envelope.schema.types[1] is CompositeType) - - var amqpSchemaA: CompositeType? = null - var amqpSchemaB: CompositeType? = null - - for (type in obj.envelope.schema.types) { - when (type.name.split("$").last()) { - "A" -> amqpSchemaA = type as CompositeType - "B" -> amqpSchemaB = type as CompositeType - } - } - - require(amqpSchemaA != null) - require(amqpSchemaB != null) - - // Just ensure the amqp schema matches what we want before we go messing - // around with the internals - assertEquals(1, amqpSchemaA?.fields?.size) - assertEquals("a", amqpSchemaA!!.fields[0].name) - assertEquals("int", amqpSchemaA.fields[0].type) - - assertEquals(2, amqpSchemaB?.fields?.size) - assertEquals("a", amqpSchemaB!!.fields[0].name) - assertEquals(classTestName("A"), amqpSchemaB.fields[0].type) - assertEquals("b", amqpSchemaB.fields[1].name) - assertEquals("int", amqpSchemaB.fields[1].type) - - val metaSchema = obj.envelope.schema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // if we know all the classes there is nothing to really achieve here - require(metaSchema.carpenterSchemas.isEmpty()) - require(metaSchema.dependsOn.isEmpty()) - require(metaSchema.dependencies.isEmpty()) + // We load an unknown class, B_mangled, which includes a reference to a known class, A. + assertCanLoadAll(envelope.getMangled()) } - // you cannot have an element of a composite class we know about - // that is unknown as that should be impossible. If we have the containing - // class in the class path then we must have all of it's constituent elements - @Test(expected = UncarpentableException::class) - fun nestedIsUnknown() { - val testA = 10 - val testB = 20 - + @Test + fun bothAreUnknown() { @CordaSerializable data class A(override val a: Int) : I_ @CordaSerializable data class B(val a: A, var b: Int) - val b = B(A(testA), testB) + val (_, envelope) = B(A(10), 20).roundTrip() - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) - val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"))) - - amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) + // We load an unknown class, B_mangled, which includes a reference to an unknown class, A_mangled. + // For this to work, we must include A_mangled in our set of classes to load. + assertCanLoadAll(envelope.getMangled().mangle(), envelope.getMangled()) } @Test - fun ParentIsUnknown() { - val testA = 10 - val testB = 20 - + fun oneIsUnknown() { @CordaSerializable data class A(override val a: Int) : I_ @CordaSerializable data class B(val a: A, var b: Int) - val b = B(A(testA), testB) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + val (_, envelope) = B(A(10), 20).roundTrip() - val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("B"))) - val carpenterSchema = amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) + // We load an unknown class, B_mangled, which includes a reference to an unknown class, A_mangled. + // This will fail, because A_mangled is not included in our set of classes to load. + assertFailsWith { assertCanLoadAll(envelope.getMangled().mangle()) } + } - assertEquals(1, carpenterSchema.size) + // See https://github.com/corda/corda/issues/4107 + @Test + fun withUUID() { + @CordaSerializable + data class IOUStateData( + val value: Int, + val ref: UUID, + val newValue: String? = null + ) - val metaCarpenter = MetaCarpenter(carpenterSchema, ClassCarpenterImpl(whitelist = AllWhitelist)) - - metaCarpenter.build() - - require(mangleName(classTestName("B")) in metaCarpenter.objects) + val uuid = UUID.randomUUID() + val(_, envelope) = IOUStateData(10, uuid, "new value").roundTrip() + val recarpented = envelope.getMangled().load() + val instance = recarpented.new(null, uuid, 10) + assertEquals(uuid, instance.get("ref")) } @Test - fun BothUnknown() { - val testA = 10 - val testB = 20 + fun mapWithUnknown() { + data class C(val a: Int) + data class D(val m: Map) + val (_, envelope) = D(mapOf("c" to C(1))).roundTrip() - @CordaSerializable - data class A(override val a: Int) : I_ + val infoForD = envelope.typeInformationFor().mangle() + val mangledMap = envelope.typeInformation.values.find { it.typeIdentifier.name == "java.util.Map" }!!.mangle() + val mangledC = envelope.getMangled() - @CordaSerializable - data class B(val a: A, var b: Int) + assertEquals( + "java.util.Map", + mangledMap.prettyPrint(false)) - val b = B(A(testA), testB) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) - val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) - val carpenterSchema = amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // just verify we're in the expected initial state, A is carpentable, B is not because - // it depends on A and the dependency chains are in place - assertEquals(1, carpenterSchema.size) - assertEquals(mangleName(classTestName("A")), carpenterSchema.carpenterSchemas.first().name) - assertEquals(1, carpenterSchema.dependencies.size) - require(mangleName(classTestName("B")) in carpenterSchema.dependencies) - assertEquals(1, carpenterSchema.dependsOn.size) - require(mangleName(classTestName("A")) in carpenterSchema.dependsOn) - - val metaCarpenter = TestMetaCarpenter(carpenterSchema, ClassCarpenterImpl(whitelist = AllWhitelist)) - - assertEquals(0, metaCarpenter.objects.size) - - // first iteration, carpent A, resolve deps and mark B as carpentable - metaCarpenter.build() - - // one build iteration should have carpetned up A and worked out that B is now buildable - // given it's depedencies have been satisfied - assertTrue(mangleName(classTestName("A")) in metaCarpenter.objects) - assertFalse(mangleName(classTestName("B")) in metaCarpenter.objects) - - assertEquals(1, carpenterSchema.carpenterSchemas.size) - assertEquals(mangleName(classTestName("B")), carpenterSchema.carpenterSchemas.first().name) - assertTrue(carpenterSchema.dependencies.isEmpty()) - assertTrue(carpenterSchema.dependsOn.isEmpty()) - - // second manual iteration, will carpent B - metaCarpenter.build() - require(mangleName(classTestName("A")) in metaCarpenter.objects) - require(mangleName(classTestName("B")) in metaCarpenter.objects) - - // and we must be finished - assertTrue(carpenterSchema.carpenterSchemas.isEmpty()) + assertCanLoadAll(infoForD, mangledMap, mangledC) } - @Test(expected = UncarpentableException::class) - @Suppress("UNUSED") - fun nestedIsUnknownInherited() { - val testA = 10 - val testB = 20 - val testC = 30 - - @CordaSerializable - open class A(val a: Int) - - @CordaSerializable - class B(a: Int, var b: Int) : A(a) - - @CordaSerializable - data class C(val b: B, var c: Int) - - val c = C(B(testA, testB), testC) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c)) - - val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) - - amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - } - - @Test(expected = UncarpentableException::class) - @Suppress("UNUSED") - fun nestedIsUnknownInheritedUnknown() { - val testA = 10 - val testB = 20 - val testC = 30 - - @CordaSerializable - open class A(val a: Int) - - @CordaSerializable - class B(a: Int, var b: Int) : A(a) - - @CordaSerializable - data class C(val b: B, var c: Int) - - val c = C(B(testA, testB), testC) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c)) - - val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) - - amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - } - - @Suppress("UNUSED") - @Test(expected = UncarpentableException::class) - fun parentsIsUnknownWithUnknownInheritedMember() { - val testA = 10 - val testB = 20 - val testC = 30 - - @CordaSerializable - open class A(val a: Int) - - @CordaSerializable - class B(a: Int, var b: Int) : A(a) - - @CordaSerializable - data class C(val b: B, var c: Int) - - val c = C(B(testA, testB), testC) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c)) - - val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) - TestMetaCarpenter(carpenterSchema.carpenterSchema( - ClassLoader.getSystemClassLoader()), ClassCarpenterImpl(whitelist = AllWhitelist)) - } - - /* - * TODO serializer doesn't support inheritnace at the moment, when it does this should work @Test - fun `inheritance`() { - val testA = 10 - val testB = 20 + fun parameterisedNonCollectionWithUnknown() { + data class C(val a: Int) + data class NotAMap(val key: K, val value: V) + data class D(val m: NotAMap) + val (_, envelope) = D(NotAMap("c" , C(1))).roundTrip() - @CordaSerializable - open class A(open val a: Int) + val infoForD = envelope.typeInformationFor().mangle() + val mangledNotAMap = envelope.typeInformationFor>().mangle() + val mangledC = envelope.getMangled() - @CordaSerializable - class B(override val a: Int, val b: Int) : A (a) - - val b = B(testA, testB) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) - - require(obj.obj is B) - - val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) - val metaCarpenter = TestMetaCarpenter(carpenterSchema.carpenterSchema()) - - assertEquals(1, metaCarpenter.schemas.carpenterSchemas.size) - assertEquals(mangleNames(classTestName("B")), metaCarpenter.schemas.carpenterSchemas.first().name) - assertEquals(1, metaCarpenter.schemas.dependencies.size) - assertTrue(mangleNames(classTestName("A")) in metaCarpenter.schemas.dependencies) + assertCanLoadAll(infoForD, mangledNotAMap, mangledC) } - */ } - diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt index 6a1a2c6e3e..336471c9ac 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt @@ -2,10 +2,9 @@ package net.corda.serialization.internal.carpenter import net.corda.core.serialization.CordaSerializable import net.corda.serialization.internal.AllWhitelist -import net.corda.serialization.internal.amqp.DeserializationInput import org.junit.Test import kotlin.test.* -import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope +import java.io.NotSerializableException @CordaSerializable interface J { @@ -39,172 +38,68 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { fun interfaceParent1() { class A(override val j: Int) : J - val testJ = 20 - val a = A(testJ) + val (_, env) = A(20).roundTrip() + val mangledA = env.getMangled() - assertEquals(testJ, a.j) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val serSchema = obj.envelope.schema - assertEquals(2, serSchema.types.size) - val l1 = serSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) + val carpentedA = mangledA.load() + val carpentedInstance = carpentedA.new(20) - // since we're using an envelope generated by seilaising classes defined locally - // it's extremely unlikely we'd need to carpent any classes - assertEquals(0, l1.size) + assertEquals(20, carpentedInstance.get("j")) - val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) - val l2 = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - assertEquals(1, l2.size) - - val aSchema = l2.carpenterSchemas.find { it.name == mangleName(classTestName("A")) } - assertNotEquals(null, aSchema) - assertEquals(mangleName(classTestName("A")), aSchema!!.name) - assertEquals(1, aSchema.interfaces.size) - assertEquals(J::class.java, aSchema.interfaces[0]) - - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val objJ = aBuilder.constructors[0].newInstance(testJ) - val j = objJ as J - - assertEquals(aBuilder.getMethod("getJ").invoke(objJ), testJ) - assertEquals(a.j, j.j) + val asJ = carpentedInstance as J + assertEquals(20, asJ.j) } @Test fun interfaceParent2() { class A(override val j: Int, val jj: Int) : J - val testJ = 20 - val testJJ = 40 - val a = A(testJ, testJJ) + val (_, env) = A(23, 42).roundTrip() + val carpentedA = env.getMangled().load() + val carpetedInstance = carpentedA.constructors[0].newInstance(23, 42) - assertEquals(testJ, a.j) - assertEquals(testJJ, a.jj) + assertEquals(23, carpetedInstance.get("j")) + assertEquals(42, carpetedInstance.get("jj")) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - - val serSchema = obj.envelope.schema - - assertEquals(2, serSchema.types.size) - - val l1 = serSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - assertEquals(0, l1.size) - - val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) - val aName = mangleName(classTestName("A")) - val l2 = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - assertEquals(1, l2.size) - - val aSchema = l2.carpenterSchemas.find { it.name == aName } - - assertNotEquals(null, aSchema) - - assertEquals(aName, aSchema!!.name) - assertEquals(1, aSchema.interfaces.size) - assertEquals(J::class.java, aSchema.interfaces[0]) - - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val objJ = aBuilder.constructors[0].newInstance(testJ, testJJ) - val j = objJ as J - - assertEquals(aBuilder.getMethod("getJ").invoke(objJ), testJ) - assertEquals(aBuilder.getMethod("getJj").invoke(objJ), testJJ) - - assertEquals(a.j, j.j) + val asJ = carpetedInstance as J + assertEquals(23, asJ.j) } @Test fun multipleInterfaces() { - val testI = 20 - val testII = 40 - class A(override val i: Int, override val ii: Int) : I, II - val a = A(testI, testII) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + val (_, env) = A(23, 42).roundTrip() + val carpentedA = env.getMangled().load() + val carpetedInstance = carpentedA.constructors[0].newInstance(23, 42) - val serSchema = obj.envelope.schema + assertEquals(23, carpetedInstance.get("i")) + assertEquals(42, carpetedInstance.get("ii")) - assertEquals(3, serSchema.types.size) + val i = carpetedInstance as I + val ii = carpetedInstance as II - val l1 = serSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // since we're using an envelope generated by serialising classes defined locally - // it's extremely unlikely we'd need to carpent any classes - assertEquals(0, l1.size) - - // pretend we don't know the class we've been sent, i.e. it's unknown to the class loader, and thus - // needs some carpentry - val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) - val l2 = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - val aName = mangleName(classTestName("A")) - - assertEquals(1, l2.size) - - val aSchema = l2.carpenterSchemas.find { it.name == aName } - - assertNotEquals(null, aSchema) - assertEquals(aName, aSchema!!.name) - assertEquals(2, aSchema.interfaces.size) - assertTrue(I::class.java in aSchema.interfaces) - assertTrue(II::class.java in aSchema.interfaces) - - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val objA = aBuilder.constructors[0].newInstance(testI, testII) - val i = objA as I - val ii = objA as II - - assertEquals(aBuilder.getMethod("getI").invoke(objA), testI) - assertEquals(aBuilder.getMethod("getIi").invoke(objA), testII) - assertEquals(a.i, i.i) - assertEquals(a.ii, ii.ii) + assertEquals(23, i.i) + assertEquals(42, ii.ii) } @Test fun nestedInterfaces() { class A(override val i: Int, override val iii: Int) : III - val testI = 20 - val testIII = 60 - val a = A(testI, testIII) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + val (_, env) = A(23, 42).roundTrip() + val carpentedA = env.getMangled().load() + val carpetedInstance = carpentedA.constructors[0].newInstance(23, 42) - val serSchema = obj.envelope.schema + assertEquals(23, carpetedInstance.get("i")) + assertEquals(42, carpetedInstance.get("iii")) - assertEquals(3, serSchema.types.size) + val i = carpetedInstance as I + val iii = carpetedInstance as III - val l1 = serSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // since we're using an envelope generated by serialising classes defined locally - // it's extremely unlikely we'd need to carpent any classes - assertEquals(0, l1.size) - - val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) - val l2 = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - val aName = mangleName(classTestName("A")) - - assertEquals(1, l2.size) - - val aSchema = l2.carpenterSchemas.find { it.name == aName } - - assertNotEquals(null, aSchema) - assertEquals(aName, aSchema!!.name) - assertEquals(2, aSchema.interfaces.size) - assertTrue(I::class.java in aSchema.interfaces) - assertTrue(III::class.java in aSchema.interfaces) - - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val objA = aBuilder.constructors[0].newInstance(testI, testIII) - val i = objA as I - val iii = objA as III - - assertEquals(aBuilder.getMethod("getI").invoke(objA), testI) - assertEquals(aBuilder.getMethod("getIii").invoke(objA), testIII) - assertEquals(a.i, i.i) - assertEquals(a.i, iii.i) - assertEquals(a.iii, iii.iii) + assertEquals(23, i.i) + assertEquals(23, iii.i) + assertEquals(42, iii.iii) } @Test @@ -212,237 +107,60 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { class A(override val i: Int) : I class B(override val i: I, override val iiii: Int) : IIII - val testI = 25 - val testIIII = 50 - val a = A(testI) - val b = B(a, testIIII) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + val (_, env) = B(A(23), 42).roundTrip() + val carpentedA = env.getMangled().load() + val carpentedB = env.getMangled().load() - val serSchema = obj.envelope.schema + val carpentedAInstance = carpentedA.new(23) + val carpentedBInstance = carpentedB.new(carpentedAInstance, 42) - // Expected classes are - // * class A - // * class A's interface (class I) - // * class B - // * class B's interface (class IIII) - assertEquals(4, serSchema.types.size) - - val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"), classTestName("B"))) - val cSchema = mangleSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - val aName = mangleName(classTestName("A")) - val bName = mangleName(classTestName("B")) - - assertEquals(2, cSchema.size) - - val aCarpenterSchema = cSchema.carpenterSchemas.find { it.name == aName } - val bCarpenterSchema = cSchema.carpenterSchemas.find { it.name == bName } - - assertNotEquals(null, aCarpenterSchema) - assertNotEquals(null, bCarpenterSchema) - - val cc = ClassCarpenterImpl(whitelist = AllWhitelist) - val cc2 = ClassCarpenterImpl(whitelist = AllWhitelist) - val bBuilder = cc.build(bCarpenterSchema!!) - bBuilder.constructors[0].newInstance(a, testIIII) - - val aBuilder = cc.build(aCarpenterSchema!!) - val objA = aBuilder.constructors[0].newInstance(testI) - - // build a second B this time using our constructed instance of A and not the - // local one we pre defined - bBuilder.constructors[0].newInstance(objA, testIIII) - - // whittle and instantiate a different A with a new class loader - val aBuilder2 = cc2.build(aCarpenterSchema) - val objA2 = aBuilder2.constructors[0].newInstance(testI) - - bBuilder.constructors[0].newInstance(objA2, testIIII) + val iiii = carpentedBInstance as IIII + assertEquals(23, iiii.i.i) + assertEquals(42, iiii.iiii) } - // if we remove the nested interface we should get an error as it's impossible - // to have a concrete class loaded without having access to all of it's elements - @Test(expected = UncarpentableException::class) + @Test fun memberInterface2() { class A(override val i: Int) : I - class B(override val i: I, override val iiii: Int) : IIII - val testI = 25 - val testIIII = 50 - val a = A(testI) - val b = B(a, testIIII) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + val (_, env) = A(23).roundTrip() - val serSchema = obj.envelope.schema - - // The classes we're expecting to find: - // * class A - // * class A's interface (class I) - // * class B - // * class B's interface (class IIII) - assertEquals(4, serSchema.types.size) - - // ignore the return as we expect this to throw - serSchema.mangleNames(listOf( - classTestName("A"), "${this.javaClass.`package`.name}.I")).carpenterSchema(ClassLoader.getSystemClassLoader()) + // if we remove the nested interface we should get an error as it's impossible + // to have a concrete class loaded without having access to all of it's elements + assertFailsWith { assertCanLoadAll(env.getMangled().mangle()) } } @Test fun interfaceAndImplementation() { class A(override val i: Int) : I - val testI = 25 - val a = A(testI) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + val (_, env) = A(23).roundTrip() - val serSchema = obj.envelope.schema - - // The classes we're expecting to find: - // * class A - // * class A's interface (class I) - assertEquals(2, serSchema.types.size) - - val amqpSchema = serSchema.mangleNames(listOf(classTestName("A"), "${this.javaClass.`package`.name}.I")) - val aName = mangleName(classTestName("A")) - val iName = mangleName("${this.javaClass.`package`.name}.I") - val carpenterSchema = amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // whilst there are two unknown classes within the envelope A depends on I so we can't construct a - // schema for A until we have for I - assertEquals(1, carpenterSchema.size) - assertNotEquals(null, carpenterSchema.carpenterSchemas.find { it.name == iName }) - - // since we can't build A it should list I as a dependency - assertTrue(aName in carpenterSchema.dependencies) - assertEquals(1, carpenterSchema.dependencies[aName]!!.second.size) - assertEquals(iName, carpenterSchema.dependencies[aName]!!.second[0]) - - // and conversly I should have A listed as a dependent - assertTrue(iName in carpenterSchema.dependsOn) - assertEquals(1, carpenterSchema.dependsOn[iName]!!.size) - assertEquals(aName, carpenterSchema.dependsOn[iName]!![0]) - - val mc = MetaCarpenter(carpenterSchema, ClassCarpenterImpl(whitelist = AllWhitelist)) - mc.build() - - assertEquals(0, mc.schemas.carpenterSchemas.size) - assertEquals(0, mc.schemas.dependencies.size) - assertEquals(0, mc.schemas.dependsOn.size) - assertEquals(2, mc.objects.size) - assertTrue(aName in mc.objects) - assertTrue(iName in mc.objects) - - mc.objects[aName]!!.constructors[0].newInstance(testI) + // This time around we will succeed, because the mangled I is included in the type information to be loaded. + assertCanLoadAll(env.getMangled().mangle(), env.getMangled()) } @Test fun twoInterfacesAndImplementation() { class A(override val i: Int, override val ii: Int) : I, II - val testI = 69 - val testII = 96 - val a = A(testI, testII) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - - val amqpSchema = obj.envelope.schema.mangleNames(listOf( - classTestName("A"), - "${this.javaClass.`package`.name}.I", - "${this.javaClass.`package`.name}.II")) - - val aName = mangleName(classTestName("A")) - val iName = mangleName("${this.javaClass.`package`.name}.I") - val iiName = mangleName("${this.javaClass.`package`.name}.II") - val carpenterSchema = amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // there is nothing preventing us from carpenting up the two interfaces so - // our initial list should contain both interface with A being dependent on both - // and each having A as a dependent - assertEquals(2, carpenterSchema.carpenterSchemas.size) - assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iName }) - assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iiName }) - assertNull(carpenterSchema.carpenterSchemas.find { it.name == aName }) - - assertTrue(iName in carpenterSchema.dependsOn) - assertEquals(1, carpenterSchema.dependsOn[iName]?.size) - assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == aName })) - - assertTrue(iiName in carpenterSchema.dependsOn) - assertEquals(1, carpenterSchema.dependsOn[iiName]?.size) - assertNotNull(carpenterSchema.dependsOn[iiName]?.find { it == aName }) - - assertTrue(aName in carpenterSchema.dependencies) - assertEquals(2, carpenterSchema.dependencies[aName]!!.second.size) - assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName }) - assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiName }) - - val mc = MetaCarpenter(carpenterSchema, ClassCarpenterImpl(whitelist = AllWhitelist)) - mc.build() - - assertEquals(0, mc.schemas.carpenterSchemas.size) - assertEquals(0, mc.schemas.dependencies.size) - assertEquals(0, mc.schemas.dependsOn.size) - assertEquals(3, mc.objects.size) - assertTrue(aName in mc.objects) - assertTrue(iName in mc.objects) - assertTrue(iiName in mc.objects) + val (_, env) = A(23, 42).roundTrip() + assertCanLoadAll( + env.getMangled().mangle().mangle(), + env.getMangled(), + env.getMangled() + ) } @Test fun nestedInterfacesAndImplementation() { class A(override val i: Int, override val iii: Int) : III - val testI = 7 - val testIII = 11 - val a = A(testI, testIII) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - - val amqpSchema = obj.envelope.schema.mangleNames(listOf( - classTestName("A"), - "${this.javaClass.`package`.name}.I", - "${this.javaClass.`package`.name}.III")) - - val aName = mangleName(classTestName("A")) - val iName = mangleName("${this.javaClass.`package`.name}.I") - val iiiName = mangleName("${this.javaClass.`package`.name}.III") - val carpenterSchema = amqpSchema.carpenterSchema(ClassLoader.getSystemClassLoader()) - - // Since A depends on III and III extends I we will have to construct them - // in that reverse order (I -> III -> A) - assertEquals(1, carpenterSchema.carpenterSchemas.size) - assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iName }) - assertNull(carpenterSchema.carpenterSchemas.find { it.name == iiiName }) - assertNull(carpenterSchema.carpenterSchemas.find { it.name == aName }) - - // I has III as a direct dependent and A as an indirect one - assertTrue(iName in carpenterSchema.dependsOn) - assertEquals(2, carpenterSchema.dependsOn[iName]?.size) - assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == iiiName })) - assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == aName })) - - // III has A as a dependent - assertTrue(iiiName in carpenterSchema.dependsOn) - assertEquals(1, carpenterSchema.dependsOn[iiiName]?.size) - assertNotNull(carpenterSchema.dependsOn[iiiName]?.find { it == aName }) - - // conversly III depends on I - assertTrue(iiiName in carpenterSchema.dependencies) - assertEquals(1, carpenterSchema.dependencies[iiiName]!!.second.size) - assertNotNull(carpenterSchema.dependencies[iiiName]!!.second.find { it == iName }) - - // and A depends on III and I - assertTrue(aName in carpenterSchema.dependencies) - assertEquals(2, carpenterSchema.dependencies[aName]!!.second.size) - assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiiName }) - assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName }) - - val mc = MetaCarpenter(carpenterSchema, ClassCarpenterImpl(whitelist = AllWhitelist)) - mc.build() - - assertEquals(0, mc.schemas.carpenterSchemas.size) - assertEquals(0, mc.schemas.dependencies.size) - assertEquals(0, mc.schemas.dependsOn.size) - assertEquals(3, mc.objects.size) - assertTrue(aName in mc.objects) - assertTrue(iName in mc.objects) - assertTrue(iiiName in mc.objects) + val (_, env) = A(23, 42).roundTrip() + assertCanLoadAll( + env.getMangled().mangle().mangle(), + env.getMangled(), + env.getMangled().mangle() + ) } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt index 3a560566bb..f5848a8f3b 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt @@ -1,57 +1,24 @@ package net.corda.serialization.internal.carpenter import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializableCalculatedProperty import net.corda.serialization.internal.AllWhitelist -import net.corda.serialization.internal.amqp.CompositeType -import net.corda.serialization.internal.amqp.DeserializationInput import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertNotEquals -import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { @Test - fun twoInts() { + fun anIntAndALong() { @CordaSerializable - data class A(val a: Int, val b: Int) + data class A(val a: Int, val b: Long) - val testA = 10 - val testB = 20 - val a = A(testA, testB) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + val (_, env) = A(23, 42).roundTrip() + val carpentedInstance = env.getMangled().load().new(23, 42) - val amqpObj = obj.obj - - assertEquals(testA, amqpObj.a) - assertEquals(testB, amqpObj.b) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - - assertEquals(2, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("int", amqpSchema.fields[0].type) - assertEquals("b", amqpSchema.fields[1].name) - assertEquals("int", amqpSchema.fields[1].type) - - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - assertEquals(1, carpenterSchema.size) - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") } - - assertNotEquals(null, aSchema) - - val pinochio = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema!!) - val p = pinochio.constructors[0].newInstance(testA, testB) - - assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a) - assertEquals(pinochio.getMethod("getB").invoke(p), amqpObj.b) + assertEquals(23, carpentedInstance.get("a")) + assertEquals(42L, carpentedInstance.get("b")) } @Test @@ -59,42 +26,65 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhi @CordaSerializable data class A(val a: Int, val b: String) - val testA = 10 - val testB = "twenty" - val a = A(testA, testB) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + val (_, env) = A(23, "skidoo").roundTrip() + val carpentedInstance = env.getMangled().load().new(23, "skidoo") - val amqpObj = obj.obj + assertEquals(23, carpentedInstance.get("a")) + assertEquals("skidoo", carpentedInstance.get("b")) + } - assertEquals(testA, amqpObj.a) - assertEquals(testB, amqpObj.b) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) + interface Parent { + @get:SerializableCalculatedProperty + val doubled: Int + } - val amqpSchema = obj.envelope.schema.types[0] as CompositeType + @Test + fun calculatedValues() { + data class C(val i: Int): Parent { + @get:SerializableCalculatedProperty + val squared = (i * i).toString() - assertEquals(2, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("int", amqpSchema.fields[0].type) - assertEquals("b", amqpSchema.fields[1].name) - assertEquals("string", amqpSchema.fields[1].type) + override val doubled get() = i * 2 + } - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) + val (amqpObj, envelope) = C(2).roundTrip() + val remoteTypeInformation = envelope.typeInformationFor() - assertEquals(1, carpenterSchema.size) - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") } + assertEquals(""" + C: Parent + doubled: int + i: int + squared: String + """.trimIndent(), remoteTypeInformation.prettyPrint()) - assertNotEquals(null, aSchema) + val pinochio = remoteTypeInformation.mangle().load() + assertNotEquals(pinochio.name, C::class.java.name) + assertNotEquals(pinochio, C::class.java) - val pinochio = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema!!) - val p = pinochio.constructors[0].newInstance(testA, testB) + // Note that params are given in alphabetical order: doubled, i, squared + val p = pinochio.new(4, 2, "4") - assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a) - assertEquals(pinochio.getMethod("getB").invoke(p), amqpObj.b) + assertEquals(2, p.get("i")) + assertEquals("4", p.get("squared")) + assertEquals(4, p.get("doubled")) + + val upcast = p as Parent + assertEquals(upcast.doubled, amqpObj.doubled) + } + + @Test + fun implementingClassDoesNotCalculateValue() { + class C(override val doubled: Int): Parent + + val (_, env) = C(5).roundTrip() + + val pinochio = env.getMangled().load() + val p = pinochio.new(5) + + assertEquals(5, p.get("doubled")) + + val upcast = p as Parent + assertEquals(5, upcast.doubled) } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/SingleMemberCompositeSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/SingleMemberCompositeSchemaToClassCarpenterTests.kt deleted file mode 100644 index 5642b0b824..0000000000 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/SingleMemberCompositeSchemaToClassCarpenterTests.kt +++ /dev/null @@ -1,205 +0,0 @@ -package net.corda.serialization.internal.carpenter - -import net.corda.core.serialization.CordaSerializable -import net.corda.serialization.internal.AllWhitelist -import net.corda.serialization.internal.amqp.CompositeType -import net.corda.serialization.internal.amqp.DeserializationInput -import org.junit.Test -import kotlin.test.assertEquals -import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope - -class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { - @Test - fun singleInteger() { - @CordaSerializable - data class A(val a: Int) - - val test = 10 - val a = A(test) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val amqpObj = obj.obj - - assertEquals(test, amqpObj.a) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - - assertEquals(1, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("int", amqpSchema.fields[0].type) - - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = aBuilder.constructors[0].newInstance(test) - - assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) - } - - @Test - fun singleString() { - @CordaSerializable - data class A(val a: String) - - val test = "ten" - val a = A(test) - - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val amqpObj = obj.obj - - assertEquals(test, amqpObj.a) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = aBuilder.constructors[0].newInstance(test) - - assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) - } - - @Test - fun singleLong() { - @CordaSerializable - data class A(val a: Long) - - val test = 10L - val a = A(test) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val amqpObj = obj.obj - - assertEquals(test, amqpObj.a) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - - assertEquals(1, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("long", amqpSchema.fields[0].type) - - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = aBuilder.constructors[0].newInstance(test) - - assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) - } - - @Test - fun singleShort() { - @CordaSerializable - data class A(val a: Short) - - val test = 10.toShort() - val a = A(test) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val amqpObj = obj.obj - - assertEquals(test, amqpObj.a) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - - assertEquals(1, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("short", amqpSchema.fields[0].type) - - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = aBuilder.constructors[0].newInstance(test) - - assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) - } - - @Test - fun singleDouble() { - @CordaSerializable - data class A(val a: Double) - - val test = 10.0 - val a = A(test) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val amqpObj = obj.obj - - assertEquals(test, amqpObj.a) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - - assertEquals(1, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("double", amqpSchema.fields[0].type) - - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = aBuilder.constructors[0].newInstance(test) - - assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) - } - - @Test - fun singleFloat() { - @CordaSerializable - data class A(val a: Float) - - val test = 10.0F - val a = A(test) - val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) - val amqpObj = obj.obj - - assertEquals(test, amqpObj.a) - assertEquals(1, obj.envelope.schema.types.size) - require(obj.envelope.schema.types[0] is CompositeType) - - val amqpSchema = obj.envelope.schema.types[0] as CompositeType - - assertEquals(1, amqpSchema.fields.size) - assertEquals("a", amqpSchema.fields[0].name) - assertEquals("float", amqpSchema.fields[0].type) - - val carpenterSchema = CarpenterMetaSchema.newInstance() - amqpSchema.carpenterSchema( - classloader = ClassLoader.getSystemClassLoader(), - carpenterSchemas = carpenterSchema, - force = true) - - val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! - val aBuilder = ClassCarpenterImpl(whitelist = AllWhitelist).build(aSchema) - val p = aBuilder.constructors[0].newInstance(test) - - assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) - } -} diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt index 1e79409733..f7b942ccd1 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt @@ -54,7 +54,7 @@ class ClassCarpentingTypeLoaderTests { val person = personType.make("Arthur Putey", 42, address, listOf(previousAddress)) val personJson = ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(person) .replace("\r\n", "\n") - + assertEquals(""" { "name" : "Arthur Putey", diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt index d14646337f..7a10fdb560 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/model/LocalTypeModelTests.kt @@ -59,9 +59,9 @@ class LocalTypeModelTests { assertInformation(""" Nested(collectionHolder: StringKeyedCollectionHolder?, intArray: int[], optionalParam: Short?) collectionHolder (optional): StringKeyedCollectionHolder(list: List, map: Map, array: List[]): CollectionHolder - array: List[] - list: List - map: Map + array: List[] + list: List + map: Map intArray: int[] """) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/model/TypeIdentifierTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/model/TypeIdentifierTests.kt index f07f88526a..f4bc6b3f3a 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/model/TypeIdentifierTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/model/TypeIdentifierTests.kt @@ -43,6 +43,19 @@ class TypeIdentifierTests { TypeIdentifier.forGenericType(fieldType, HasStringArray::class.java).prettyPrint()) } + @Test + fun `roundtrip`() { + assertRoundtrips(Int::class.javaPrimitiveType!!) + assertRoundtrips() + assertRoundtrips() + assertRoundtrips(List::class.java) + assertRoundtrips>() + assertRoundtrips>>() + assertRoundtrips() + assertRoundtrips(HasArray::class.java) + assertRoundtrips>() + } + private fun assertIdentified(type: Type, expected: String) = assertEquals(expected, TypeIdentifier.forGenericType(type).prettyPrint()) @@ -50,4 +63,12 @@ class TypeIdentifierTests { assertEquals(expected, TypeIdentifier.forGenericType(typeOf()).prettyPrint()) private inline fun typeOf() = object : TypeToken() {}.type + + private inline fun assertRoundtrips() = assertRoundtrips(typeOf()) + + private fun assertRoundtrips(original: Type) { + val identifier = TypeIdentifier.forGenericType(original) + val localType = identifier.getLocalType(classLoader = ClassLoader.getSystemClassLoader()) + assertIdentified(localType, identifier.prettyPrint()) + } } \ No newline at end of file diff --git a/tools/shell/src/test/kotlin/net/corda/tools/shell/InteractiveShellTest.kt b/tools/shell/src/test/kotlin/net/corda/tools/shell/InteractiveShellTest.kt index 7fb3af428a..aed7852dd3 100644 --- a/tools/shell/src/test/kotlin/net/corda/tools/shell/InteractiveShellTest.kt +++ b/tools/shell/src/test/kotlin/net/corda/tools/shell/InteractiveShellTest.kt @@ -57,7 +57,7 @@ class InteractiveShellTest { }, input, FlowA::class.java, om) assertEquals(expected, output!!, input) } - + @Test fun flowStartSimple() { check("a: Hi there", "Hi there")