Merge pull request #2024 from corda/christians_public-key-caching-hooks

Public key customization hooks
This commit is contained in:
Christian Sailer
2017-11-10 13:37:37 +00:00
committed by GitHub
8 changed files with 131 additions and 45 deletions

3
.idea/compiler.xml generated
View File

@ -40,6 +40,9 @@
<module name="docs_source_example-code_main" target="1.8" /> <module name="docs_source_example-code_main" target="1.8" />
<module name="docs_source_example-code_test" target="1.8" /> <module name="docs_source_example-code_test" target="1.8" />
<module name="docs_test" target="1.8" /> <module name="docs_test" target="1.8" />
<module name="example-code_integrationTest" target="1.8" />
<module name="example-code_main" target="1.8" />
<module name="example-code_test" target="1.8" />
<module name="experimental_main" target="1.8" /> <module name="experimental_main" target="1.8" />
<module name="experimental_test" target="1.8" /> <module name="experimental_test" target="1.8" />
<module name="explorer-capsule_main" target="1.6" /> <module name="explorer-capsule_main" target="1.6" />

View File

@ -19,7 +19,7 @@ class KryoClientSerializationScheme : AbstractKryoSerializationScheme() {
override fun rpcClientKryoPool(context: SerializationContext): KryoPool { override fun rpcClientKryoPool(context: SerializationContext): KryoPool {
return KryoPool.Builder { return KryoPool.Builder {
DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, context)).apply { DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, context), publicKeySerializer).apply {
classLoader = context.deserializationClassLoader classLoader = context.deserializationClassLoader
} }
}.build() }.build()

View File

@ -7,6 +7,7 @@ import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.DefaultWhitelist import net.corda.nodeapi.internal.serialization.DefaultWhitelist
import net.corda.nodeapi.internal.serialization.MutableClassWhitelist import net.corda.nodeapi.internal.serialization.MutableClassWhitelist
import net.corda.nodeapi.internal.serialization.SerializationScheme import net.corda.nodeapi.internal.serialization.SerializationScheme
import java.security.PublicKey
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
@ -28,47 +29,49 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme {
private val serializationWhitelists: List<SerializationWhitelist> by lazy { private val serializationWhitelists: List<SerializationWhitelist> by lazy {
ServiceLoader.load(SerializationWhitelist::class.java, this::class.java.classLoader).toList() + DefaultWhitelist ServiceLoader.load(SerializationWhitelist::class.java, this::class.java.classLoader).toList() + DefaultWhitelist
} }
}
fun registerCustomSerializers(factory: SerializerFactory) { private fun registerCustomSerializers(factory: SerializerFactory) {
with(factory) { with(factory) {
register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer) register(publicKeySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.PrivateKeySerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.PrivateKeySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.OpaqueBytesSubSequenceSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.OpaqueBytesSubSequenceSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.DurationSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.DurationSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.LocalDateSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.LocalDateSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.LocalDateTimeSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.LocalDateTimeSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.LocalTimeSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.LocalTimeSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.ZonedDateTimeSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.ZonedDateTimeSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.ZoneIdSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.ZoneIdSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.OffsetTimeSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.OffsetTimeSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.OffsetDateTimeSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.OffsetDateTimeSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.YearSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.YearSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.YearMonthSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.YearMonthSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.MonthDaySerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.MonthDaySerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.PeriodSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.PeriodSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.ClassSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.ClassSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.X509CertificateHolderSerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.X509CertificateHolderSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.PartyAndCertificateSerializer(factory)) register(net.corda.nodeapi.internal.serialization.amqp.custom.PartyAndCertificateSerializer(factory))
register(net.corda.nodeapi.internal.serialization.amqp.custom.StringBufferSerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.StringBufferSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.SimpleStringSerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.SimpleStringSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.InputStreamSerializer) register(net.corda.nodeapi.internal.serialization.amqp.custom.InputStreamSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.BitSetSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.BitSetSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.EnumSetSerializer(this)) register(net.corda.nodeapi.internal.serialization.amqp.custom.EnumSetSerializer(this))
}
for (whitelistProvider in serializationWhitelists)
factory.addToWhitelist(*whitelistProvider.whitelist.toTypedArray())
} }
for (whitelistProvider in serializationWhitelists)
factory.addToWhitelist(*whitelistProvider.whitelist.toTypedArray())
} }
private val serializerFactoriesForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, SerializerFactory>() private val serializerFactoriesForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, SerializerFactory>()
protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory
protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory
open protected val publicKeySerializer: CustomSerializer.Implements<PublicKey>
= net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer
private fun getSerializerFactory(context: SerializationContext): SerializerFactory { private fun getSerializerFactory(context: SerializationContext): SerializerFactory {
return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {

View File

@ -51,6 +51,7 @@ import java.io.ByteArrayOutputStream
import java.io.FileInputStream import java.io.FileInputStream
import java.io.InputStream import java.io.InputStream
import java.lang.reflect.Modifier.isPublic import java.lang.reflect.Modifier.isPublic
import java.security.PublicKey
import java.security.cert.CertPath import java.security.cert.CertPath
import java.util.* import java.util.*
import kotlin.collections.ArrayList import kotlin.collections.ArrayList
@ -60,7 +61,7 @@ object DefaultKryoCustomizer {
ServiceLoader.load(SerializationWhitelist::class.java, this.javaClass.classLoader).toList() + DefaultWhitelist ServiceLoader.load(SerializationWhitelist::class.java, this.javaClass.classLoader).toList() + DefaultWhitelist
} }
fun customize(kryo: Kryo): Kryo { fun customize(kryo: Kryo, publicKeySerializer: Serializer<PublicKey> = PublicKeySerializer): Kryo {
return kryo.apply { return kryo.apply {
// Store a little schema of field names in the stream the first time a class is used which increases tolerance // Store a little schema of field names in the stream the first time a class is used which increases tolerance
// for change to a class. // for change to a class.
@ -95,10 +96,10 @@ object DefaultKryoCustomizer {
register(BufferedInputStream::class.java, InputStreamSerializer) register(BufferedInputStream::class.java, InputStreamSerializer)
register(Class.forName("sun.net.www.protocol.jar.JarURLConnection\$JarURLInputStream"), InputStreamSerializer) register(Class.forName("sun.net.www.protocol.jar.JarURLConnection\$JarURLInputStream"), InputStreamSerializer)
noReferencesWithin<WireTransaction>() noReferencesWithin<WireTransaction>()
register(ECPublicKeyImpl::class.java, PublicKeySerializer) register(ECPublicKeyImpl::class.java, publicKeySerializer)
register(EdDSAPublicKey::class.java, PublicKeySerializer) register(EdDSAPublicKey::class.java, publicKeySerializer)
register(EdDSAPrivateKey::class.java, PrivateKeySerializer) register(EdDSAPrivateKey::class.java, PrivateKeySerializer)
register(CompositeKey::class.java, PublicKeySerializer) // Using a custom serializer for compactness register(CompositeKey::class.java, publicKeySerializer) // Using a custom serializer for compactness
// Exceptions. We don't bother sending the stack traces as the client will fill in its own anyway. // Exceptions. We don't bother sending the stack traces as the client will fill in its own anyway.
register(Array<StackTraceElement>::class, read = { _, _ -> emptyArray() }, write = { _, _, _ -> }) register(Array<StackTraceElement>::class, read = { _, _ -> emptyArray() }, write = { _, _, _ -> })
// This ensures a NonEmptySetSerializer is constructed with an initial value. // This ensures a NonEmptySetSerializer is constructed with an initial value.
@ -111,11 +112,11 @@ object DefaultKryoCustomizer {
register(X500Name::class.java, X500NameSerializer) register(X500Name::class.java, X500NameSerializer)
register(X509CertificateHolder::class.java, X509CertificateSerializer) register(X509CertificateHolder::class.java, X509CertificateSerializer)
register(BCECPrivateKey::class.java, PrivateKeySerializer) register(BCECPrivateKey::class.java, PrivateKeySerializer)
register(BCECPublicKey::class.java, PublicKeySerializer) register(BCECPublicKey::class.java, publicKeySerializer)
register(BCRSAPrivateCrtKey::class.java, PrivateKeySerializer) register(BCRSAPrivateCrtKey::class.java, PrivateKeySerializer)
register(BCRSAPublicKey::class.java, PublicKeySerializer) register(BCRSAPublicKey::class.java, publicKeySerializer)
register(BCSphincs256PrivateKey::class.java, PrivateKeySerializer) register(BCSphincs256PrivateKey::class.java, PrivateKeySerializer)
register(BCSphincs256PublicKey::class.java, PublicKeySerializer) register(BCSphincs256PublicKey::class.java, publicKeySerializer)
register(NotaryChangeWireTransaction::class.java, NotaryChangeWireTransactionSerializer) register(NotaryChangeWireTransaction::class.java, NotaryChangeWireTransactionSerializer)
register(PartyAndCertificate::class.java, PartyAndCertificateSerializer) register(PartyAndCertificate::class.java, PartyAndCertificateSerializer)

View File

@ -18,6 +18,7 @@ import net.corda.core.serialization.*
import net.corda.core.internal.LazyPool import net.corda.core.internal.LazyPool
import net.corda.nodeapi.internal.serialization.CordaClassResolver import net.corda.nodeapi.internal.serialization.CordaClassResolver
import net.corda.nodeapi.internal.serialization.SerializationScheme import net.corda.nodeapi.internal.serialization.SerializationScheme
import java.security.PublicKey
// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB // "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB
val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray(Charsets.UTF_8)) val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray(Charsets.UTF_8))
@ -39,6 +40,9 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool
protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool
// this can be overriden in derived serialization schemes
open protected val publicKeySerializer: Serializer<PublicKey> = PublicKeySerializer
private fun getPool(context: SerializationContext): KryoPool { private fun getPool(context: SerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) { when (context.useCase) {
@ -50,6 +54,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply { serializer.kryo.apply {
field.set(this, classResolver) field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this) DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer) register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
@ -62,7 +67,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
rpcServerKryoPool(context) rpcServerKryoPool(context)
else -> else ->
KryoPool.Builder { KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context))).apply { classLoader = it.second } DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context)), publicKeySerializer).apply { classLoader = it.second }
}.build() }.build()
} }
} }

View File

@ -0,0 +1,65 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.core.serialization.SerializationContext
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.AMQP_P2P_CONTEXT
import org.apache.qpid.proton.codec.Data
import org.assertj.core.api.Assertions
import org.junit.Test
import java.lang.reflect.Type
import java.security.PublicKey
class OverridePKSerializerTest {
class SerializerTestException(message: String) : Exception(message)
class TestPublicKeySerializer : CustomSerializer.Implements<PublicKey>(PublicKey::class.java) {
override fun writeDescribedObject(obj: PublicKey, data: Data, type: Type, output: SerializationOutput) {
throw SerializerTestException("Custom write call")
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): PublicKey {
throw SerializerTestException("Custom read call")
}
override val schemaForDocumentation: Schema
get() = TODO("not implemented") //To change initializer of created properties use File | Settings | File Templates.
}
class AMQPTestSerializationScheme : AbstractAMQPSerializationScheme() {
override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean = true
override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override val publicKeySerializer = TestPublicKeySerializer()
}
class TestPublicKey : PublicKey {
override fun getAlgorithm(): String {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun getEncoded(): ByteArray {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun getFormat(): String {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
}
@Test
fun `test publicKeySerializer is overridden`() {
val scheme = AMQPTestSerializationScheme()
val key = TestPublicKey()
Assertions
.assertThatThrownBy { scheme.serialize(key, AMQP_P2P_CONTEXT) }
.hasMessageMatching("Custom write call")
}
}

View File

@ -36,6 +36,10 @@ import java.nio.ByteBuffer
import java.time.* import java.time.*
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
import java.util.* import java.util.*
import kotlin.reflect.full.declaredFunctions
import kotlin.reflect.full.declaredMemberFunctions
import kotlin.reflect.full.superclasses
import kotlin.reflect.jvm.javaMethod
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertNotNull import kotlin.test.assertNotNull
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -557,11 +561,16 @@ class SerializationOutputTests {
fun `test transaction state`() { fun `test transaction state`() {
val state = TransactionState(FooState(), FOO_PROGRAM_ID, MEGA_CORP) val state = TransactionState(FooState(), FOO_PROGRAM_ID, MEGA_CORP)
val scheme = AMQPServerSerializationScheme()
val func = scheme::class.superclasses.single { it.simpleName == "AbstractAMQPSerializationScheme" }
.java.getDeclaredMethod("registerCustomSerializers", SerializerFactory::class.java)
func.isAccessible = true
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
AbstractAMQPSerializationScheme.registerCustomSerializers(factory) func.invoke(scheme, factory)
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader()) val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
AbstractAMQPSerializationScheme.registerCustomSerializers(factory2) func.invoke(scheme, factory2)
val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false) val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false)
assertTrue((desState as TransactionState<*>).data is FooState) assertTrue((desState as TransactionState<*>).data is FooState)

View File

@ -18,7 +18,7 @@ class KryoServerSerializationScheme : AbstractKryoSerializationScheme() {
override fun rpcServerKryoPool(context: SerializationContext): KryoPool { override fun rpcServerKryoPool(context: SerializationContext): KryoPool {
return KryoPool.Builder { return KryoPool.Builder {
DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, context)).apply { DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, context), publicKeySerializer).apply {
classLoader = context.deserializationClassLoader classLoader = context.deserializationClassLoader
} }
}.build() }.build()