From 1fb1d6fb72b0ca6608d97d6b4aa3f601463bee4e Mon Sep 17 00:00:00 2001 From: Katelyn Baker Date: Mon, 11 Dec 2017 18:32:12 +0000 Subject: [PATCH] CORDA-852 - Fix AMQP serialisation of nested generic --- .../carpenter/AMQPSchemaExtensions.kt | 3 +- .../serialization/amqp/GenericsTests.kt | 114 ++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt index 6a0c3c9784..6117695b77 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt @@ -122,7 +122,8 @@ val typeStrToType: Map, Class> = mapOf( fun AMQPField.getTypeAsClass(classloader: ClassLoader) = typeStrToType[Pair(type, mandatory)] ?: when (type) { "string" -> String::class.java - "*" -> classloader.loadClass(requires[0]) + "binary" -> ByteArray::class.java + "*" -> if (requires.isEmpty()) Any::class.java else classloader.loadClass(requires[0]) else -> classloader.loadClass(type) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt new file mode 100644 index 0000000000..9882c79b41 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/GenericsTests.kt @@ -0,0 +1,114 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import net.corda.core.serialization.SerializedBytes +import net.corda.nodeapi.internal.serialization.AllWhitelist +import org.junit.Test +import kotlin.test.assertEquals + +class GenericsTests { + + @Test + fun nestedSerializationOfGenerics() { + data class G(val a: T) + data class Wrapper(val a: Int, val b: G) + + val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + val altContextFactory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + val ser = SerializationOutput(factory) + + val bytes = ser.serializeAndReturnSchema(G("hi")) + + assertEquals("hi", DeserializationInput(factory).deserialize(bytes.obj).a) + assertEquals("hi", DeserializationInput(altContextFactory).deserialize(bytes.obj).a) + + val bytes2 = ser.serializeAndReturnSchema(Wrapper(1, G("hi"))) + + DeserializationInput(factory).deserialize(bytes2.obj).apply { + assertEquals(1, a) + assertEquals("hi", b.a) + } + + DeserializationInput(altContextFactory).deserialize(bytes2.obj).apply { + assertEquals(1, a) + assertEquals("hi", b.a) + } + } + + @Test + fun nestedGenericsReferencesByteArrayViaSerializedBytes() { + data class G(val a : Int) + data class Wrapper(val a: Int, val b: SerializedBytes) + + val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + val ser = SerializationOutput(factory) + + val gBytes = ser.serialize(G(1)) + val bytes2 = ser.serializeAndReturnSchema(Wrapper(1, gBytes)) + + DeserializationInput(factory).deserialize(bytes2.obj).apply { + assertEquals(1, a) + assertEquals(1, DeserializationInput(factory).deserialize(b).a) + } + DeserializationInput(factory2).deserialize(bytes2.obj).apply { + assertEquals(1, a) + assertEquals(1, DeserializationInput(factory).deserialize(b).a) + } + } + + @Test + fun nestedSerializationInMultipleContextsDoesntColideGenericTypes() { + data class InnerA(val a_a: Int) + data class InnerB(val a_b: Int) + data class InnerC(val a_c: String) + data class Container(val b: T) + data class Wrapper(val c: Container) + + val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) + val factories = listOf(factory, SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())) + val ser = SerializationOutput(factory) + + ser.serialize(Wrapper(Container(InnerA(1)))).apply { + factories.forEach { + DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_a) } + } + } + + ser.serialize(Wrapper(Container(InnerB(1)))).apply { + factories.forEach { + DeserializationInput(it).deserialize(this).apply { assertEquals(1, c.b.a_b) } + } + } + + ser.serialize(Wrapper(Container(InnerC("Ho ho ho")))).apply { + factories.forEach { + DeserializationInput(it).deserialize(this).apply { assertEquals("Ho ho ho", c.b.a_c) } + } + } + } + + @Test + fun nestedSerializationWhereGenericDoesntImpactFingerprint() { + data class Inner(val a : Int) + data class Container(val b: Inner) + data class Wrapper(val c: Container) + + val factorys = listOf( + SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()), + SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())) + + val ser = SerializationOutput(factorys[0]) + + ser.serialize(Wrapper(Container(Inner(1)))).apply { + factorys.forEach { + assertEquals(1, DeserializationInput(it).deserialize(this).c.b.a) + } + } + + ser.serialize(Wrapper(Container(Inner(1)))).apply { + factorys.forEach { + assertEquals(1, DeserializationInput(it).deserialize(this).c.b.a) + } + } + } +} \ No newline at end of file