diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt index 2f888c70fe..78f3a9752f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializationScheme.kt @@ -7,6 +7,7 @@ import net.corda.core.utilities.ByteSequence import net.corda.nodeapi.internal.serialization.DefaultWhitelist import net.corda.nodeapi.internal.serialization.MutableClassWhitelist import net.corda.nodeapi.internal.serialization.SerializationScheme +import java.security.PublicKey import java.util.* import java.util.concurrent.ConcurrentHashMap @@ -29,9 +30,9 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme { ServiceLoader.load(SerializationWhitelist::class.java, this::class.java.classLoader).toList() + DefaultWhitelist } - fun registerCustomSerializers(factory: SerializerFactory) { + fun registerCustomSerializers(factory: SerializerFactory, publicKeySerializer: CustomSerializer.Implements = net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer) { with(factory) { - register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer) + register(publicKeySerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.PrivateKeySerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer) @@ -69,6 +70,7 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme { protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory + open protected val publicKeySerializer = net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer private fun getSerializerFactory(context: SerializationContext): SerializerFactory { return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { @@ -81,7 +83,7 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme { rpcServerSerializerFactory(context) else -> SerializerFactory(context.whitelist, context.deserializationClassLoader) } - }.also { registerCustomSerializers(it) } + }.also { registerCustomSerializers(it, publicKeySerializer) } } override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T {