diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializerRegistry.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializerRegistry.kt index 61478dc43c..433b33b899 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializerRegistry.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/CustomSerializerRegistry.kt @@ -27,7 +27,18 @@ class CachingCustomSerializerRegistry( private data class CustomSerializerIdentifier(val actualTypeIdentifier: TypeIdentifier, val declaredTypeIdentifier: TypeIdentifier) - private val customSerializersCache: MutableMap> = DefaultCacheProvider.createCache() + private sealed class CustomSerializerLookupResult { + + abstract val serializerIfFound: AMQPSerializer? + + object None : CustomSerializerLookupResult() { + override val serializerIfFound: AMQPSerializer? = null + } + + data class CustomSerializerFound(override val serializerIfFound: AMQPSerializer) : CustomSerializerLookupResult() + } + + private val customSerializersCache: MutableMap = DefaultCacheProvider.createCache() private var customSerializers: List = emptyList() /** @@ -37,6 +48,11 @@ class CachingCustomSerializerRegistry( override fun register(customSerializer: CustomSerializer) { logger.trace("action=\"Registering custom serializer\", class=\"${customSerializer.type}\"") + if (!customSerializersCache.isEmpty()) { + logger.warn("Attempting to register custom serializer $customSerializer.type} in an active cache." + + "All serializers should be registered before the cache comes into use.") + } + descriptorBasedSerializerRegistry.getOrBuild(customSerializer.typeDescriptor.toString()) { customSerializers += customSerializer for (additional in customSerializer.additionalSerializers) { @@ -49,6 +65,11 @@ class CachingCustomSerializerRegistry( override fun registerExternal(customSerializer: CorDappCustomSerializer) { logger.trace("action=\"Registering external serializer\", class=\"${customSerializer.type}\"") + if (!customSerializersCache.isEmpty()) { + logger.warn("Attempting to register custom serializer ${customSerializer.type} in an active cache." + + "All serializers must be registered before the cache comes into use.") + } + descriptorBasedSerializerRegistry.getOrBuild(customSerializer.typeDescriptor.toString()) { customSerializers += customSerializer customSerializer @@ -60,10 +81,11 @@ class CachingCustomSerializerRegistry( TypeIdentifier.forClass(clazz), TypeIdentifier.forGenericType(declaredType)) - return customSerializersCache[typeIdentifier] - ?: doFindCustomSerializer(clazz, declaredType)?.also { serializer -> - customSerializersCache.putIfAbsent(typeIdentifier, serializer) - } + return customSerializersCache.getOrPut(typeIdentifier) { + val customSerializer = doFindCustomSerializer(clazz, declaredType) + if (customSerializer == null) CustomSerializerLookupResult.None + else CustomSerializerLookupResult.CustomSerializerFound(customSerializer) + }.serializerIfFound } private fun doFindCustomSerializer(clazz: Class<*>, declaredType: Type): AMQPSerializer? { diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt index 5a671d5373..173fb9ccec 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/LocalSerializerFactory.kt @@ -7,10 +7,8 @@ import net.corda.core.utilities.debug import net.corda.core.utilities.trace import net.corda.serialization.internal.model.* import org.apache.qpid.proton.amqp.Symbol -import java.io.NotSerializableException import java.lang.reflect.ParameterizedType import java.lang.reflect.Type -import java.lang.reflect.WildcardType import java.util.* import javax.annotation.concurrent.ThreadSafe @@ -90,7 +88,10 @@ class DefaultLocalSerializerFactory( val logger = contextLogger() } - private val serializersByType: MutableMap> = DefaultCacheProvider.createCache() + private data class ActualAndDeclaredType(val actualType: Class<*>, val declaredType: Type) + + private val serializersByActualAndDeclaredType: MutableMap> = DefaultCacheProvider.createCache() + private val serializersByTypeId: MutableMap> = DefaultCacheProvider.createCache() private val typesByName = DefaultCacheProvider.createCache>() override fun createDescriptor(typeInformation: LocalTypeInformation): Symbol = @@ -101,10 +102,10 @@ class DefaultLocalSerializerFactory( override fun getTypeInformation(typeName: String): LocalTypeInformation? { return typesByName.getOrPut(typeName) { val localType = try { - Class.forName(typeName, false, classloader) - } catch (_: ClassNotFoundException) { - null - } + Class.forName(typeName, false, classloader) + } catch (_: ClassNotFoundException) { + null + } Optional.ofNullable(localType?.run { getTypeInformation(this) }) }.orElse(null) } @@ -112,73 +113,85 @@ class DefaultLocalSerializerFactory( override fun get(typeInformation: LocalTypeInformation): AMQPSerializer = get(typeInformation.observedType, typeInformation) - private fun make(typeInformation: LocalTypeInformation, build: () -> AMQPSerializer) = - make(typeInformation.typeIdentifier, build) + private fun makeAndCache(typeInformation: LocalTypeInformation, build: () -> AMQPSerializer) = + makeAndCache(typeInformation.typeIdentifier, build) - private fun make(typeIdentifier: TypeIdentifier, build: () -> AMQPSerializer) = - serializersByType.computeIfAbsent(typeIdentifier) { _ -> build() } - - private fun get(declaredType: Type, localTypeInformation: LocalTypeInformation): AMQPSerializer { - val declaredClass = declaredType.asClass() - - // can be useful to enable but will be *extremely* chatty if you do - logger.trace { "Get Serializer for $declaredClass ${declaredType.typeName}" } - - return when(localTypeInformation) { - is LocalTypeInformation.ACollection -> makeDeclaredCollection(localTypeInformation) - is LocalTypeInformation.AMap -> makeDeclaredMap(localTypeInformation) - is LocalTypeInformation.AnEnum -> makeDeclaredEnum(localTypeInformation, declaredType, declaredClass) - else -> makeClassSerializer(declaredClass, declaredType, declaredType, localTypeInformation) - }.also { serializer -> descriptorBasedSerializerRegistry[serializer.typeDescriptor.toString()] = serializer } - } - - private fun makeDeclaredEnum(localTypeInformation: LocalTypeInformation, declaredType: Type, declaredClass: Class<*>): AMQPSerializer = - make(localTypeInformation) { - whitelist.requireWhitelisted(declaredType) - EnumSerializer(declaredType, declaredClass, this) + private fun makeAndCache(typeIdentifier: TypeIdentifier, build: () -> AMQPSerializer) = + serializersByTypeId.getOrPut(typeIdentifier) { + build().also { serializer -> + descriptorBasedSerializerRegistry[serializer.typeDescriptor.toString()] = serializer + } } + private fun get(declaredType: Type, localTypeInformation: LocalTypeInformation): AMQPSerializer = + serializersByTypeId.getOrPut(localTypeInformation.typeIdentifier) { + val declaredClass = declaredType.asClass() + + // can be useful to enable but will be *extremely* chatty if you do + logger.trace { "Get Serializer for $declaredClass ${declaredType.typeName}" } + customSerializerRegistry.findCustomSerializer(declaredClass, declaredType)?.apply { return@get this } + + return when (localTypeInformation) { + is LocalTypeInformation.ACollection -> makeDeclaredCollection(localTypeInformation) + is LocalTypeInformation.AMap -> makeDeclaredMap(localTypeInformation) + is LocalTypeInformation.AnEnum -> makeDeclaredEnum(localTypeInformation, declaredType, declaredClass) + else -> makeClassSerializer(declaredClass, declaredType, localTypeInformation) + } + } + + private fun makeDeclaredEnum(localTypeInformation: LocalTypeInformation, declaredType: Type, declaredClass: Class<*>): AMQPSerializer = + makeAndCache(localTypeInformation) { + whitelist.requireWhitelisted(declaredType) + EnumSerializer(declaredType, declaredClass, this) + } + private fun makeActualEnum(localTypeInformation: LocalTypeInformation, declaredType: Type, declaredClass: Class<*>): AMQPSerializer = - make(localTypeInformation) { + makeAndCache(localTypeInformation) { whitelist.requireWhitelisted(declaredType) EnumSerializer(declaredType, declaredClass, this) } private fun makeDeclaredCollection(localTypeInformation: LocalTypeInformation.ACollection): AMQPSerializer { val resolved = CollectionSerializer.resolveDeclared(localTypeInformation) - return make(resolved) { + return makeAndCache(resolved) { CollectionSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) } } private fun makeDeclaredMap(localTypeInformation: LocalTypeInformation.AMap): AMQPSerializer { val resolved = MapSerializer.resolveDeclared(localTypeInformation) - return make(resolved) { + return makeAndCache(resolved) { MapSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) } } override fun get(actualClass: Class<*>, declaredType: Type): AMQPSerializer { - // can be useful to enable but will be *extremely* chatty if you do - logger.trace { "Get Serializer for $actualClass ${declaredType.typeName}" } + val actualAndDeclaredType = ActualAndDeclaredType(actualClass, declaredType) + return serializersByActualAndDeclaredType.getOrPut(actualAndDeclaredType) { + // can be useful to enable but will be *extremely* chatty if you do + logger.trace { "Get Serializer for $actualClass ${declaredType.typeName}" } + customSerializerRegistry.findCustomSerializer(actualClass, declaredType)?.apply { return@get this } - val declaredClass = declaredType.asClass() - val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType - val declaredTypeInformation = typeModel.inspect(declaredType) - val actualTypeInformation = typeModel.inspect(actualType) + val declaredClass = declaredType.asClass() + val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType + val declaredTypeInformation = typeModel.inspect(declaredType) + val actualTypeInformation = typeModel.inspect(actualType) - return when(actualTypeInformation) { - is LocalTypeInformation.ACollection -> makeActualCollection(actualClass,declaredTypeInformation as? LocalTypeInformation.ACollection ?: actualTypeInformation) - is LocalTypeInformation.AMap -> makeActualMap(declaredType, actualClass,declaredTypeInformation as? LocalTypeInformation.AMap ?: actualTypeInformation) - is LocalTypeInformation.AnEnum -> makeActualEnum(actualTypeInformation, actualType, actualClass) - else -> makeClassSerializer(actualClass, actualType, declaredType, actualTypeInformation) - }.also { serializer -> descriptorBasedSerializerRegistry[serializer.typeDescriptor.toString()] = serializer } + return when (actualTypeInformation) { + is LocalTypeInformation.ACollection -> makeActualCollection(actualClass, declaredTypeInformation as? LocalTypeInformation.ACollection + ?: actualTypeInformation) + is LocalTypeInformation.AMap -> makeActualMap(declaredType, actualClass, declaredTypeInformation as? LocalTypeInformation.AMap + ?: actualTypeInformation) + is LocalTypeInformation.AnEnum -> makeActualEnum(actualTypeInformation, actualType, actualClass) + else -> makeClassSerializer(actualClass, actualType, actualTypeInformation) + } + } } private fun makeActualMap(declaredType: Type, actualClass: Class<*>, typeInformation: LocalTypeInformation.AMap): AMQPSerializer { declaredType.asClass().checkSupportedMapType() val resolved = MapSerializer.resolveActual(actualClass, typeInformation) - return make(resolved) { + return makeAndCache(resolved) { MapSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) } } @@ -186,7 +199,7 @@ class DefaultLocalSerializerFactory( private fun makeActualCollection(actualClass: Class<*>, typeInformation: LocalTypeInformation.ACollection): AMQPSerializer { val resolved = CollectionSerializer.resolveActual(actualClass, typeInformation) - return serializersByType.computeIfAbsent(resolved.typeIdentifier) { + return makeAndCache(resolved) { CollectionSerializer(resolved.typeIdentifier.getLocalType(classloader) as ParameterizedType, this) } } @@ -194,9 +207,8 @@ class DefaultLocalSerializerFactory( private fun makeClassSerializer( clazz: Class<*>, type: Type, - declaredType: Type, typeInformation: LocalTypeInformation - ): AMQPSerializer = make(typeInformation) { + ): AMQPSerializer = makeAndCache(typeInformation) { logger.debug { "class=${clazz.simpleName}, type=$type is a composite type" } when { clazz.isSynthetic -> // Explicitly ban synthetic classes, we have no way of recreating them when deserializing. This also @@ -205,8 +217,7 @@ class DefaultLocalSerializerFactory( type, "Serializer does not support synthetic classes") AMQPTypeIdentifiers.isPrimitive(typeInformation.typeIdentifier) -> AMQPPrimitiveSerializer(clazz) - else -> customSerializerRegistry.findCustomSerializer(clazz, declaredType) ?: - makeNonCustomSerializer(type, typeInformation, clazz) + else -> makeNonCustomSerializer(type, typeInformation, clazz) } }