diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt index 314c9b6aa9..de803afc37 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt @@ -10,6 +10,8 @@ import kotlin.collections.Map import kotlin.collections.iterator import kotlin.collections.map +private typealias MapCreationFunction = (Map<*, *>) -> Map<*, *> + /** * Serialization / deserialization of certain supported [Map] types. */ @@ -18,7 +20,8 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" companion object { - private val supportedTypes: Map, (Map<*, *>) -> Map<*, *>> = mapOf( + // NB: Order matters in this map, the most specific classes should be listed at the end + private val supportedTypes: Map>, MapCreationFunction> = Collections.unmodifiableMap(linkedMapOf( // Interfaces Map::class.java to { map -> Collections.unmodifiableMap(map) }, SortedMap::class.java to { map -> Collections.unmodifiableSortedMap(TreeMap(map)) }, @@ -26,13 +29,36 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial // concrete classes for user convenience LinkedHashMap::class.java to { map -> LinkedHashMap(map) }, TreeMap::class.java to { map -> TreeMap(map) } - ) - private fun findConcreteType(clazz: Class<*>): (Map<*, *>) -> Map<*, *> { + )) + + private fun findConcreteType(clazz: Class<*>): MapCreationFunction { return supportedTypes[clazz] ?: throw NotSerializableException("Unsupported map type $clazz.") } + + fun deriveParameterizedType(declaredType: Type, declaredClass: Class<*>, actualClass: Class<*>?): ParameterizedType { + if(supportedTypes.containsKey(declaredClass)) { + // Simple case - it is already known to be a map. + @Suppress("UNCHECKED_CAST") + return deriveParametrizedType(declaredType, declaredClass as Class>) + } + else if (actualClass != null && Map::class.java.isAssignableFrom(actualClass)) { + // Declared class is not map, but [actualClass] is - represent it accordingly. + val mapClass = findMostSuitableMapType(actualClass) + return deriveParametrizedType(declaredType, mapClass) + } + + throw NotSerializableException("Cannot derive map type for declaredType: '$declaredType', declaredClass: '$declaredClass', actualClass: '$actualClass'") + } + + private fun deriveParametrizedType(declaredType: Type, collectionClass: Class>): ParameterizedType = + (declaredType as? ParameterizedType) ?: DeserializedParameterizedType(collectionClass, arrayOf(SerializerFactory.AnyType, SerializerFactory.AnyType)) + + + private fun findMostSuitableMapType(actualClass: Class<*>): Class> = + MapSerializer.supportedTypes.keys.findLast { it.isAssignableFrom(actualClass) }!! } - private val concreteBuilder: (Map<*, *>) -> Map<*, *> = findConcreteType(declaredType.rawType as Class<*>) + private val concreteBuilder: MapCreationFunction = findConcreteType(declaredType.rawType as Class<*>) private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "map", Descriptor(typeDescriptor, null), emptyList()) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index e98704aa1d..408c6400b9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -4,7 +4,6 @@ import com.google.common.primitives.Primitives import com.google.common.reflect.TypeResolver import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable -import net.corda.nodeapi.internal.serialization.amqp.CollectionSerializer.Companion.deriveParameterizedType import net.corda.nodeapi.internal.serialization.carpenter.* import org.apache.qpid.proton.amqp.* import java.io.NotSerializableException @@ -70,20 +69,23 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) { val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType val serializer = when { + // Declared class may not be set to Collection, but actual class could be a collection. + // In this case use of CollectionSerializer is perfectly appropriate. (Collection::class.java.isAssignableFrom(declaredClass) || - // declared class may not be set to Collection, but actual class could be a collection. - // In this case use of CollectionSerializer is perfectly appropriate. - (actualClass != null && Collection::class.java.isAssignableFrom(actualClass))) -> { - - val declaredTypeAmended= deriveParameterizedType(declaredType, declaredClass, actualClass) - - serializersByType.computeIfAbsent(declaredTypeAmended) { - CollectionSerializer(declaredTypeAmended, this) - } + (actualClass != null && Collection::class.java.isAssignableFrom(actualClass))) -> { + val declaredTypeAmended= CollectionSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass) + serializersByType.computeIfAbsent(declaredTypeAmended) { + CollectionSerializer(declaredTypeAmended, this) + } } - Map::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) { - makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( - declaredClass, arrayOf(AnyType, AnyType), null)) + // Declared class may not be set to Map, but actual class could be a map. + // In this case use of MapSerializer is perfectly appropriate. + (Map::class.java.isAssignableFrom(declaredClass) || + (actualClass != null && Map::class.java.isAssignableFrom(actualClass))) -> { + val declaredTypeAmended= MapSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass) + serializersByType.computeIfAbsent(declaredClass) { + makeMapSerializer(declaredTypeAmended) + } } Enum::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) { EnumSerializer(actualType, actualClass ?: declaredClass, this) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt index a5e83e0c01..2d01f3466c 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt @@ -16,6 +16,7 @@ class ListsSerializationTest : TestDependencyInjectionBase() { @Test fun `check list can be serialized as root of serialization graph`() { + assertEqualAfterRoundTripSerialization(emptyList()) assertEqualAfterRoundTripSerialization(listOf(1)) assertEqualAfterRoundTripSerialization(listOf(1, 2)) } @@ -46,12 +47,12 @@ class ListsSerializationTest : TestDependencyInjectionBase() { Assertions.assertThatThrownBy { wrongPayloadType.serialize() } .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive collection type for declaredType") } +} - private inline fun assertEqualAfterRoundTripSerialization(obj: T) { +internal inline fun assertEqualAfterRoundTripSerialization(obj: T) { - val serializedForm: SerializedBytes = obj.serialize() - val deserializedInstance = serializedForm.deserialize() + val serializedForm: SerializedBytes = obj.serialize() + val deserializedInstance = serializedForm.deserialize() - assertEquals(obj, deserializedInstance) - } + assertEquals(obj, deserializedInstance) } \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt new file mode 100644 index 0000000000..78e0e10c31 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt @@ -0,0 +1,57 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.serialize +import net.corda.node.services.statemachine.SessionData +import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.amqpSpecific +import org.assertj.core.api.Assertions +import org.junit.Test +import org.bouncycastle.asn1.x500.X500Name +import java.io.NotSerializableException + +class MapsSerializationTest : TestDependencyInjectionBase() { + + private val smallMap = mapOf("foo" to "bar", "buzz" to "bull") + + @Test + fun `check EmptyMap serialization`() = amqpSpecific("kotlin.collections.EmptyMap is not enabled for Kryo serialization") { + assertEqualAfterRoundTripSerialization(emptyMap()) + } + + @Test + fun `check Map can be root of serialization graph`() { + assertEqualAfterRoundTripSerialization(smallMap) + } + + @Test + fun `check list can be serialized as part of SessionData`() { + val sessionData = SessionData(123, smallMap) + assertEqualAfterRoundTripSerialization(sessionData) + } + + @CordaSerializable + data class WrongPayloadType(val payload: HashMap) + + @Test + fun `check throws for forbidden declared type`() = amqpSpecific("Such exceptions are not expected in Kryo mode.") { + val payload = HashMap(smallMap) + val wrongPayloadType = WrongPayloadType(payload) + Assertions.assertThatThrownBy { wrongPayloadType.serialize() } + .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType") + } + + @CordaSerializable + data class MyKey(val keyContent: Double) + + @CordaSerializable + data class MyValue(val valueContent: X500Name) + + @Test + fun `check map serialization works with custom types`() { + val myMap = mapOf( + MyKey(1.0) to MyValue(X500Name("CN=one")), + MyKey(10.0) to MyValue(X500Name("CN=ten"))) + assertEqualAfterRoundTripSerialization(myMap) + } +} \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeMapTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeMapTests.kt index 162756b06c..37dfb8d3f4 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeMapTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeMapTests.kt @@ -1,6 +1,8 @@ package net.corda.nodeapi.internal.serialization.amqp +import org.assertj.core.api.Assertions import org.junit.Test +import java.io.NotSerializableException import java.util.* class DeserializeCollectionTests { @@ -11,7 +13,7 @@ class DeserializeCollectionTests { private const val VERBOSE = false } - val sf = testDefaultFactory() + private val sf = testDefaultFactory() @Test fun mapTest() { @@ -57,7 +59,7 @@ class DeserializeCollectionTests { DeserializationInput(sf).deserialize(serialisedBytes) } - @Test(expected=java.io.NotSerializableException::class) + @Test fun dictionaryTest() { data class C(val c: Dictionary) val v : Hashtable = Hashtable() @@ -66,10 +68,11 @@ class DeserializeCollectionTests { val c = C(v) // expected to throw - TestSerializationOutput(VERBOSE, sf).serialize(c) + Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) } + .isInstanceOf(IllegalArgumentException::class.java).hasMessageContaining("Unable to serialise deprecated type class java.util.Dictionary.") } - @Test(expected=java.lang.IllegalArgumentException::class) + @Test fun hashtableTest() { data class C(val c: Hashtable) val v : Hashtable = Hashtable() @@ -78,24 +81,27 @@ class DeserializeCollectionTests { val c = C(v) // expected to throw - TestSerializationOutput(VERBOSE, sf).serialize(c) + Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) } + .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType") } - @Test(expected=java.lang.IllegalArgumentException::class) + @Test fun hashMapTest() { data class C(val c : HashMap) val c = C (HashMap (mapOf("A" to 1, "B" to 2))) // expect this to throw - TestSerializationOutput(VERBOSE, sf).serialize(c) + Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) } + .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType") } - @Test(expected=java.lang.IllegalArgumentException::class) + @Test fun weakHashMapTest() { data class C(val c : WeakHashMap) val c = C (WeakHashMap (mapOf("A" to 1, "B" to 2))) - TestSerializationOutput(VERBOSE, sf).serialize(c) + Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) } + .isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType") } @Test