diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index 40b557f8f4..126b9af1d1 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -58,25 +58,31 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { */ @Throws(NotSerializableException::class) fun get(actualClass: Class<*>?, declaredType: Type): AMQPSerializer { - val declaredClass = declaredType.asClass() - if (declaredClass != null) { - val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType + val declaredClass = declaredType.asClass() ?: throw NotSerializableException( + "Declared types of $declaredType are not supported.") + + val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType + + val rtn = let { if (Collection::class.java.isAssignableFrom(declaredClass)) { - return serializersByType.computeIfAbsent(declaredType) { - CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(declaredClass, arrayOf(AnyType), null), this) + serializersByType.computeIfAbsent(declaredType) { + CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( + declaredClass, arrayOf(AnyType), null), this) } } else if (Map::class.java.isAssignableFrom(declaredClass)) { - return serializersByType.computeIfAbsent(declaredClass) { - makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(declaredClass, arrayOf(AnyType, AnyType), null)) + serializersByType.computeIfAbsent(declaredClass) { + makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( + declaredClass, arrayOf(AnyType, AnyType), null)) } } else { - return makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) + makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) } - } else { - throw NotSerializableException("Declared types of $declaredType are not supported.") } - } + serializersByDescriptor.putIfAbsent(rtn.typeDescriptor, rtn) + + return rtn + } /** * Try and infer concrete types for any generics type variables for the actual class encountered, based on the declared @@ -177,11 +183,10 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { for (typeNotation in schema.types) { try { processSchemaEntry(typeNotation, classCarpenter.classloader) - } - catch (e: ClassNotFoundException) { + } catch (e: ClassNotFoundException) { if (sentinal || (typeNotation !is CompositeType)) throw e typeNotation.carpenterSchema( - classLoaders = listOf (classCarpenter.classloader), carpenterSchemas = carpenterSchemas) + classLoaders = listOf(classCarpenter.classloader), carpenterSchemas = carpenterSchemas) } } @@ -193,7 +198,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } private fun processSchemaEntry(typeNotation: TypeNotation, - cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) { + cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) { when (typeNotation) { is CompositeType -> processCompositeType(typeNotation, cl) // java.lang.Class (whether a class or interface) is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics @@ -201,24 +206,19 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } private fun processRestrictedType(typeNotation: RestrictedType) { - serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { - // TODO: class loader logic, and compare the schema. - val type = typeForName(typeNotation.name) - get(null, type) - } + // TODO: class loader logic, and compare the schema. + val type = typeForName(typeNotation.name) + get(null, type) } private fun processCompositeType(typeNotation: CompositeType, - cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) { - serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { - // TODO: class loader logic, and compare the schema. - val type = typeForName(typeNotation.name, cl) - get(type.asClass() ?: throw NotSerializableException("Unable to build composite type for $type"), type) - } + cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) { + // TODO: class loader logic, and compare the schema. + val type = typeForName(typeNotation.name, cl) + get(type.asClass() ?: throw NotSerializableException("Unable to build composite type for $type"), type) } - private fun makeClassSerializer(clazz: Class<*>, type: Type, declaredType: Type): AMQPSerializer = - serializersByType.computeIfAbsent(type) { + private fun makeClassSerializer(clazz: Class<*>, type: Type, declaredType: Type): AMQPSerializer = serializersByType.computeIfAbsent(type) { if (isPrimitive(clazz)) { AMQPPrimitiveSerializer(clazz) } else { @@ -226,7 +226,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { if (type.isArray()) { whitelisted(type.componentType()) if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) - else ArraySerializer.make (type, this) + else ArraySerializer.make(type, this) } else if (clazz.kotlin.objectInstance != null) { whitelisted(clazz) SingletonSerializer(clazz, clazz.kotlin.objectInstance!!, this) @@ -236,7 +236,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } } } - } + } internal fun findCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer? { // e.g. Imagine if we provided a Map serializer this way, then it won't work if the declared type is AbstractMap, only Map. @@ -313,15 +313,15 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { private val namesOfPrimitiveTypes: Map> = primitiveTypeNames.map { it.value to it.key }.toMap() - fun nameForType(type: Type) : String = when (type) { - is Class<*> -> { - primitiveTypeName(type) ?: if (type.isArray) { - "${nameForType(type.componentType)}${if(type.componentType.isPrimitive) "[p]" else "[]"}" - } else type.name - } - is ParameterizedType -> "${nameForType(type.rawType)}<${type.actualTypeArguments.joinToString { nameForType(it) }}>" - is GenericArrayType -> "${nameForType(type.genericComponentType)}[]" - else -> throw NotSerializableException("Unable to render type $type to a string.") + fun nameForType(type: Type): String = when (type) { + is Class<*> -> { + primitiveTypeName(type) ?: if (type.isArray) { + "${nameForType(type.componentType)}${if (type.componentType.isPrimitive) "[p]" else "[]"}" + } else type.name + } + is ParameterizedType -> "${nameForType(type.rawType)}<${type.actualTypeArguments.joinToString { nameForType(it) }}>" + is GenericArrayType -> "${nameForType(type.genericComponentType)}[]" + else -> throw NotSerializableException("Unable to render type $type to a string.") } private fun typeForName( @@ -340,7 +340,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { // There is no need to handle the ByteArray case as that type is coercible automatically // to the binary type and is thus handled by the main serializer and doesn't need a // special case for a primitive array of bytes - when(name) { + when (name) { "int[p]" -> IntArray::class.java "char[p]" -> CharArray::class.java "boolean[p]" -> BooleanArray::class.java diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt index 1a5cf62c2c..76933fb370 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt @@ -4,12 +4,10 @@ import org.junit.Test import kotlin.test.* import net.corda.nodeapi.internal.serialization.carpenter.* -/** - * These tests work by having the class carpenter build the classes we serialise and then deserialise. Because - * those classes don't exist within the system's Class Loader the deserialiser will be forced to carpent - * versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This - * replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath - */ +// These tests work by having the class carpenter build the classes we serialise and then deserialise. Because +// those classes don't exist within the system's Class Loader the deserialiser will be forced to carpent +// versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This +// replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath class DeserializeNeedingCarpentrySimpleTypesTest { companion object { /** @@ -18,7 +16,8 @@ class DeserializeNeedingCarpentrySimpleTypesTest { private const val VERBOSE = false } - val sf = SerializerFactory() + val sf = SerializerFactory() + val sf2 = SerializerFactory() @Test fun singleInt() { @@ -28,8 +27,16 @@ class DeserializeNeedingCarpentrySimpleTypesTest { val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(1)) val db = DeserializationInput(sf).deserialize(sb) + val db2 = DeserializationInput(sf2).deserialize(sb) + + // despite being carpented, and thus not on the class path, we should've cached clazz + // inside the serialiser object and thus we should have created the same type + assertEquals (db::class.java, clazz) + assertNotEquals (db2::class.java, clazz) + assertNotEquals (db::class.java, db2::class.java) assertEquals(1, db::class.java.getMethod("getInt").invoke(db)) + assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2)) } @Test @@ -39,9 +46,13 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(1)) - val db = DeserializationInput(sf).deserialize(sb) + val db1 = DeserializationInput(sf).deserialize(sb) + val db2 = DeserializationInput(sf2).deserialize(sb) - assertEquals(1, db::class.java.getMethod("getInt").invoke(db)) + assertEquals(clazz, db1::class.java) + assertNotEquals(clazz, db2::class.java) + assertEquals(1, db1::class.java.getMethod("getInt").invoke(db1)) + assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2)) } @Test @@ -51,9 +62,13 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db1 = DeserializationInput(sf).deserialize(sb) + val db2 = DeserializationInput(sf2).deserialize(sb) - assertEquals(null, db::class.java.getMethod("getInt").invoke(db)) + assertEquals(clazz, db1::class.java) + assertNotEquals(clazz, db2::class.java) + assertEquals(null, db1::class.java.getMethod("getInt").invoke(db1)) + assertEquals(null, db2::class.java.getMethod("getInt").invoke(db2)) } @Test @@ -63,8 +78,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance('a')) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals('a', db::class.java.getMethod("getChar").invoke(db)) } @@ -75,8 +91,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance('a')) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals('a', db::class.java.getMethod("getChar").invoke(db)) } @@ -87,8 +104,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, db::class.java.getMethod("getChar").invoke(db)) } @@ -100,8 +118,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { val l : Long = 1 val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(l, (db::class.java.getMethod("getLong").invoke(db))) } @@ -113,8 +132,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { val l : Long = 1 val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(l, (db::class.java.getMethod("getLong").invoke(db))) } @@ -125,8 +145,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, (db::class.java.getMethod("getLong").invoke(db))) } @@ -137,8 +158,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(true)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(true, db::class.java.getMethod("getBoolean").invoke(db)) } @@ -149,8 +171,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(true)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(true, db::class.java.getMethod("getBoolean").invoke(db)) } @@ -161,8 +184,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, db::class.java.getMethod("getBoolean").invoke(db)) } @@ -173,8 +197,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(10.0, db::class.java.getMethod("getDouble").invoke(db)) } @@ -185,8 +210,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(10.0, db::class.java.getMethod("getDouble").invoke(db)) } @@ -197,8 +223,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, db::class.java.getMethod("getDouble").invoke(db)) } @@ -209,8 +236,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(3.toShort())) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(3.toShort(), db::class.java.getMethod("getShort").invoke(db)) } @@ -221,8 +249,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(3.toShort())) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(3.toShort(), db::class.java.getMethod("getShort").invoke(db)) } @@ -233,8 +262,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, db::class.java.getMethod("getShort").invoke(db)) } @@ -245,8 +275,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0F)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(10.0F, db::class.java.getMethod("getFloat").invoke(db)) } @@ -257,8 +288,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0F)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(10.0F, db::class.java.getMethod("getFloat").invoke(db)) } @@ -269,8 +301,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, db::class.java.getMethod("getFloat").invoke(db)) } @@ -282,8 +315,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { val b : Byte = 0b0101 val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(b, db::class.java.getMethod("getByte").invoke(db)) assertEquals(0b0101, (db::class.java.getMethod("getByte").invoke(db) as Byte)) } @@ -296,8 +330,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { val b : Byte = 0b0101 val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(b, db::class.java.getMethod("getByte").invoke(db)) assertEquals(0b0101, (db::class.java.getMethod("getByte").invoke(db) as Byte)) } @@ -309,8 +344,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { ))) val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) - val db = DeserializationInput(sf).deserialize(sb) + val db = DeserializationInput(sf2).deserialize(sb) + assertNotEquals(clazz, db::class.java) assertEquals(null, db::class.java.getMethod("getByte").invoke(db)) } @@ -323,8 +359,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { val classInstance = clazz.constructors[0].newInstance(testVal) val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + assertNotEquals(clazz, deserializedObj::class.java) assertTrue(deserializedObj is I) assertEquals(testVal, (deserializedObj as I).getName()) } @@ -372,8 +409,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest { 10.0F, 20.0F, null, 0b0101.toByte(), 0b1010.toByte(), null)) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + assertNotEquals(manyClass, deserializedObj::class.java) assertEquals(1, deserializedObj::class.java.getMethod("getIntA").invoke(deserializedObj)) assertEquals(2, deserializedObj::class.java.getMethod("getIntB").invoke(deserializedObj)) assertEquals(null, deserializedObj::class.java.getMethod("getIntC").invoke(deserializedObj)) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt index f1a44aef7b..d0cae9e3b6 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt @@ -9,9 +9,13 @@ interface I { } /** - * These tests work by having the class carpenter build the classes we serialise and then deserialise. Because - * those classes don't exist within the system's Class Loader the deserialiser will be forced to carpent - * versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This + * These tests work by having the class carpenter build the classes we serialise and then deserialise them + * within the context of a second serialiser factory. The second factory is required as the first, having + * been used to serialise the class, will have cached a copy of the class and will thus bypass the need + * to pull it out of the class loader. + * + * However, those classes don't exist within the system's Class Loader and thus the deserialiser will be forced + * to carpent versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This * replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath */ class DeserializeNeedingCarpentryTests { @@ -22,7 +26,8 @@ class DeserializeNeedingCarpentryTests { private const val VERBOSE = false } - val sf = SerializerFactory() + val sf1 = SerializerFactory() + val sf2 = SerializerFactory() @Test fun verySimpleType() { @@ -30,15 +35,30 @@ class DeserializeNeedingCarpentryTests { val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java)))) val classInstance = clazz.constructors[0].newInstance(testVal) - val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) - assertNotEquals(clazz::class.java, deserializedObj::class.java) - assertEquals (testVal, deserializedObj::class.java.getMethod("getA").invoke(deserializedObj)) + val deserializedObj1 = DeserializationInput(sf1).deserialize(serialisedBytes) + assertEquals(clazz, deserializedObj1::class.java) + assertEquals (testVal, deserializedObj1::class.java.getMethod("getA").invoke(deserializedObj1)) - val deserializedObj2 = DeserializationInput(sf).deserialize(serialisedBytes) + val deserializedObj2 = DeserializationInput(sf1).deserialize(serialisedBytes) + assertEquals(clazz, deserializedObj2::class.java) + assertEquals(deserializedObj1::class.java, deserializedObj2::class.java) + assertEquals (testVal, deserializedObj2::class.java.getMethod("getA").invoke(deserializedObj2)) + + val deserializedObj3 = DeserializationInput(sf2).deserialize(serialisedBytes) + assertNotEquals(clazz, deserializedObj3::class.java) + assertNotEquals(deserializedObj1::class.java, deserializedObj3::class.java) + assertNotEquals(deserializedObj2::class.java, deserializedObj3::class.java) + assertEquals (testVal, deserializedObj3::class.java.getMethod("getA").invoke(deserializedObj3)) + + val deserializedObj4 = DeserializationInput(sf2).deserialize(serialisedBytes) + assertNotEquals(clazz, deserializedObj4::class.java) + assertNotEquals(deserializedObj1::class.java, deserializedObj4::class.java) + assertNotEquals(deserializedObj2::class.java, deserializedObj4::class.java) + assertEquals(deserializedObj3::class.java, deserializedObj4::class.java) + assertEquals (testVal, deserializedObj4::class.java.getMethod("getA").invoke(deserializedObj4)) - assertEquals(deserializedObj::class.java, deserializedObj2::class.java) } @Test @@ -51,11 +71,11 @@ class DeserializeNeedingCarpentryTests { val concreteB = clazz.constructors[0].newInstance(testValB) val concreteC = clazz.constructors[0].newInstance(testValC) - val deserialisedA = DeserializationInput(sf).deserialize(TestSerializationOutput(VERBOSE, sf).serialize(concreteA)) + val deserialisedA = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(concreteA)) assertEquals (testValA, deserialisedA::class.java.getMethod("getA").invoke(deserialisedA)) - val deserialisedB = DeserializationInput(sf).deserialize(TestSerializationOutput(VERBOSE, sf).serialize(concreteB)) + val deserialisedB = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(concreteB)) assertEquals (testValB, deserialisedA::class.java.getMethod("getA").invoke(deserialisedB)) assertEquals (deserialisedA::class.java, deserialisedB::class.java) @@ -79,8 +99,8 @@ class DeserializeNeedingCarpentryTests { val testVal = "Some Person" val classInstance = clazz.constructors[0].newInstance(testVal) - val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) assertTrue(deserializedObj is I) assertEquals(testVal, (deserializedObj as I).getName()) @@ -97,7 +117,7 @@ class DeserializeNeedingCarpentryTests { clazz.constructors[0].newInstance(2), clazz.constructors[0].newInstance(3))) - val deserializedObj = DeserializationInput(sf).deserialize(TestSerializationOutput(VERBOSE, sf).serialize(outer)) + val deserializedObj = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(outer)) assertNotEquals((deserializedObj.a[0])::class.java, (outer.a[0])::class.java) assertNotEquals((deserializedObj.a[1])::class.java, (outer.a[1])::class.java) @@ -127,10 +147,10 @@ class DeserializeNeedingCarpentryTests { val inner = innerType.constructors[0].newInstance(1) val outer = outerType.constructors[0].newInstance(innerType.constructors[0].newInstance(2)) - val serializedI = TestSerializationOutput(VERBOSE, sf).serialize(inner) - val deserialisedI = DeserializationInput(sf).deserialize(serializedI) - val serialisedO = TestSerializationOutput(VERBOSE, sf).serialize(outer) - val deserialisedO = DeserializationInput(sf).deserialize(serialisedO) + val serializedI = TestSerializationOutput(VERBOSE, sf1).serialize(inner) + val deserialisedI = DeserializationInput(sf2).deserialize(serializedI) + val serialisedO = TestSerializationOutput(VERBOSE, sf1).serialize(outer) + val deserialisedO = DeserializationInput(sf2).deserialize(serialisedO) // ensure out carpented version of inner is reused assertEquals (deserialisedI::class.java, @@ -147,8 +167,8 @@ class DeserializeNeedingCarpentryTests { mapOf("inner" to NonNullableField(nestedClass)))) val classInstance = outerClass.constructors.first().newInstance(nestedClass.constructors.first().newInstance("name")) - val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) val inner = deserializedObj::class.java.getMethod("getInner").invoke(deserializedObj) assertEquals("name", inner::class.java.getMethod("getName").invoke(inner)) @@ -166,8 +186,8 @@ class DeserializeNeedingCarpentryTests { nestedClass.constructors.first().newInstance("foo"), nestedClass.constructors.first().newInstance("bar")) - val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) assertEquals ("foo", deserializedObj.a::class.java.getMethod("getName").invoke(deserializedObj.a)) assertEquals ("bar", deserializedObj.b::class.java.getMethod("getName").invoke(deserializedObj.b)) @@ -186,8 +206,8 @@ class DeserializeNeedingCarpentryTests { unknownClass.constructors.first().newInstance(5, 6), unknownClass.constructors.first().newInstance(7, 8))) - val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(toSerialise) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(toSerialise) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) var sentinel = 1 deserializedObj.l.forEach { assertEquals(sentinel++, it::class.java.getMethod("getV1").invoke(it)) @@ -208,9 +228,9 @@ class DeserializeNeedingCarpentryTests { "name" to NonNullableField(String::class.java)), interfaces = listOf (I::class.java, interfaceClass))) - val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize( + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize( concreteClass.constructors.first().newInstance(12, "timmy")) - val deserializedObj = DeserializationInput(sf).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) assertTrue(deserializedObj is I) assertEquals("timmy", (deserializedObj as I).getName()) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt index 61628705af..1e7171b31a 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt @@ -16,7 +16,8 @@ class DeserializeSimpleTypesTests { private const val VERBOSE = false } - val sf = SerializerFactory() + val sf1 = SerializerFactory() + val sf2 = SerializerFactory() @Test fun testChar() { @@ -74,8 +75,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString()) assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]") - val serialisedIA = TestSerializationOutput(VERBOSE, sf).serialize(ia) - val deserializedIA = DeserializationInput(sf).deserialize(serialisedIA) + val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) + val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) assertEquals(ia.ia.size, deserializedIA.ia.size) assertEquals(ia.ia[0], deserializedIA.ia[0]) @@ -93,8 +94,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString()) assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]") - val serialisedIA = TestSerializationOutput(VERBOSE, sf).serialize(ia) - val deserializedIA = DeserializationInput(sf).deserialize(serialisedIA) + val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) + val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) assertEquals(ia.ia.size, deserializedIA.ia.size) assertEquals(ia.ia[0], deserializedIA.ia[0]) @@ -116,8 +117,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [I", ia.ia::class.java.toString()) assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[p]") - val serialisedIA = TestSerializationOutput(VERBOSE, sf).serialize(ia) - val deserializedIA = DeserializationInput(sf).deserialize(serialisedIA) + val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) + val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) assertEquals(ia.ia.size, deserializedIA.ia.size) assertEquals(ia.ia[0], deserializedIA.ia[0]) @@ -134,8 +135,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Character;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -154,8 +155,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [C", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[p]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - var deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + var deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -166,7 +167,7 @@ class DeserializeSimpleTypesTests { v[0] = 'ই'; v[1] = ' '; v[2] = 'ਔ' val c2 = C(v) - deserializedC = DeserializationInput(sf).deserialize(TestSerializationOutput(VERBOSE, sf).serialize(c2)) + deserializedC = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(c2)) assertEquals(c2.c.size, deserializedC.c.size) assertEquals(c2.c[0], deserializedC.c[0]) @@ -183,8 +184,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Boolean;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -203,8 +204,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Z", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[p]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -222,8 +223,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Byte;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "byte[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -241,8 +242,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [B", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "binary") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -259,8 +260,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Short;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -278,8 +279,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [S", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[p]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -296,8 +297,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Long;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -315,8 +316,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [J", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[p]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -333,8 +334,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Float;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -352,8 +353,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [F", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[p]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -370,8 +371,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [Ljava.lang.Double;", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -389,8 +390,8 @@ class DeserializeSimpleTypesTests { assertEquals("class [D", c.c::class.java.toString()) assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[p]") - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0], deserializedC.c[0]) @@ -403,8 +404,8 @@ class DeserializeSimpleTypesTests { class C(val c: Array>) val c = C (arrayOf (arrayOf(1,2,3), arrayOf(4,5,6))) - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0].size, deserializedC.c[0].size) @@ -424,8 +425,8 @@ class DeserializeSimpleTypesTests { c.c[0][0] = 1; c.c[0][1] = 2; c.c[0][2] = 3 c.c[1][0] = 4; c.c[1][1] = 5; c.c[1][2] = 6 - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) assertEquals(c.c.size, deserializedC.c.size) assertEquals(c.c[0].size, deserializedC.c[0].size) @@ -448,12 +449,36 @@ class DeserializeSimpleTypesTests { for (i in 0..2) { for (j in 0..2) { for (k in 0..2) { c.c[i][j][k] = i + j + k } } } - val serialisedC = TestSerializationOutput(VERBOSE, sf).serialize(c) - val deserializedC = DeserializationInput(sf).deserialize(serialisedC) + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) for (i in 0..2) { for (j in 0..2) { for (k in 0..2) { assertEquals(c.c[i][j][k], deserializedC.c[i][j][k]) }}} } + + @Test + fun nestedRepeatedTypes() { + class A(val a : A?, val b: Int) + + var a = A(A(A(A(A(null, 1), 2), 3), 4), 5) + + val sa = TestSerializationOutput(VERBOSE, sf1).serialize(a) + val da1 = DeserializationInput(sf1).deserialize(sa) + val da2 = DeserializationInput(sf2).deserialize(sa) + + assertEquals(5, da1.b) + assertEquals(4, da1.a?.b) + assertEquals(3, da1.a?.a?.b) + assertEquals(2, da1.a?.a?.a?.b) + assertEquals(1, da1.a?.a?.a?.a?.b) + + assertEquals(5, da2.b) + assertEquals(4, da2.a?.b) + assertEquals(3, da2.a?.a?.b) + assertEquals(2, da2.a?.a?.a?.b) + assertEquals(1, da2.a?.a?.a?.a?.b) + + } }