From 532d95ccac76d2bfef250f2596641949c6435154 Mon Sep 17 00:00:00 2001 From: Christian Sailer Date: Mon, 1 Oct 2018 13:59:52 +0100 Subject: [PATCH] ENT-1565 Enable the use of BoringSSL (#1358) * BoringSsl dependency * Merge over boring_ssl changes * Merge over boring_ssl changes * Upgrade netty-tcnative (and netty to compatible version) * Add openSSL flag to SSLConfiguration and implementations. * Make SSL implementation switchable for Artemis * Parameterize AMQP bridge tests on use of openSSL * Plumb through open SSL flag to AMQP client/server. * Add open ssl flag to reference.conf * Slight clean-up * Add LoggingTrustManagerWrapper for OpenSsl contexts * Remove unneeded lazy and check for double wrapping * Fix TrustMangerWrapper and test, clean-up * Add key factory wrapper to get the current certificate chain out. * Use cert chain returning key mananager factory to get local cert * Force consistent netty-tcnative version across all dependencies * Make proton wrapper tests check all combinations of client/server native/java SSL * Add test netty server/client to run SSL tests with * Simplify usage of test netty components and clean up * Improve exception handling in NettyTestHandler * Add openSSL test for X509UtilitiesTests * Expose engine for test usage * Add the X509 peer chain check from the socket based test * Port of TLSAuthenticationTests to use Netty so we can use different SSL providers, add boringSSL tests * Adapt tests to new config structure * Readd `useOpenSsl` configuration * Readd `useOpenSsl` configuration * Fix up ArtemisTransport for OpenSSL plus tests * Adapt auth tests * Formatting * Remove obsolte file * Fix config misnomer * Add SNI host logic to OpenSSL execution branch * Remove TLS_DHE_RSA tests * Make exception handling in the netty test infra deterministic --- .../config/FirewallConfigurationImpl.kt | 3 +- build.gradle | 9 +- node-api/build.gradle | 1 + .../nodeapi/internal/ArtemisTcpTransport.kt | 10 +- .../internal/bridging/AMQPBridgeManager.kt | 9 +- .../internal/config/SslConfiguration.kt | 7 +- .../protonwrapper/netty/AMQPChannelHandler.kt | 18 +- .../protonwrapper/netty/AMQPClient.kt | 9 +- .../protonwrapper/netty/AMQPConfiguration.kt | 7 + .../protonwrapper/netty/AMQPServer.kt | 9 +- .../netty/AliasProvidingKeyMangerWrapper.kt | 94 ++++++ .../CertHoldingKeyManagerFactoryWrapper.kt | 66 ++++ .../internal/protonwrapper/netty/SSLHelper.kt | 32 ++ .../netty/TrustManagerFactoryWrapper.kt | 36 ++ .../internal/crypto/X509UtilitiesTest.kt | 73 +++- .../netty/TestKeyManagerFactoryWrapper.kt | 95 ++++++ .../netty/TestTrustManagerFactoryWrapper.kt | 51 +++ node/build.gradle | 4 + .../net/corda/node/amqp/AMQPBridgeTest.kt | 68 ++-- .../net/corda/node/amqp/ProtonWrapperTests.kt | 32 +- .../node/services/config/NodeConfiguration.kt | 5 +- node/src/main/resources/reference.conf | 1 + .../NettyEngineBasedTlsAuthenticationTests.kt | 317 ++++++++++++++++++ .../corda/testing/internal/NettyTestClient.kt | 99 ++++++ .../testing/internal/NettyTestHandler.kt | 74 ++++ .../corda/testing/internal/NettyTestServer.kt | 97 ++++++ .../internal/stubs/CertificateStoreStubs.kt | 4 +- .../testing/internal/TestNettyTestInfra.kt | 88 +++++ 28 files changed, 1267 insertions(+), 51 deletions(-) create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AliasProvidingKeyMangerWrapper.kt create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CertHoldingKeyManagerFactoryWrapper.kt create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TrustManagerFactoryWrapper.kt create mode 100644 node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestKeyManagerFactoryWrapper.kt create mode 100644 node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestTrustManagerFactoryWrapper.kt create mode 100644 node/src/test/kotlin/net/corda/node/utilities/NettyEngineBasedTlsAuthenticationTests.kt create mode 100644 testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestClient.kt create mode 100644 testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestHandler.kt create mode 100644 testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestServer.kt create mode 100644 testing/test-utils/src/test/kotlin/net/corda/testing/internal/TestNettyTestInfra.kt diff --git a/bridge/src/main/kotlin/net/corda/bridge/services/config/FirewallConfigurationImpl.kt b/bridge/src/main/kotlin/net/corda/bridge/services/config/FirewallConfigurationImpl.kt index 980c145bc6..804624f8f5 100644 --- a/bridge/src/main/kotlin/net/corda/bridge/services/config/FirewallConfigurationImpl.kt +++ b/bridge/src/main/kotlin/net/corda/bridge/services/config/FirewallConfigurationImpl.kt @@ -19,7 +19,8 @@ data class BridgeSSLConfigurationImpl(private val sslKeystore: Path, private val keyStorePassword: String, private val trustStoreFile: Path, private val trustStorePassword: String, - private val crlCheckSoftFail: Boolean) : BridgeSSLConfiguration { + private val crlCheckSoftFail: Boolean, + override val useOpenSsl: Boolean = false) : BridgeSSLConfiguration { override val keyStore = FileBasedCertificateStoreSupplier(sslKeystore, keyStorePassword) override val trustStore = FileBasedCertificateStoreSupplier(trustStoreFile, trustStorePassword) diff --git a/build.gradle b/build.gradle index 8ad9cc6b53..3291b49e7d 100644 --- a/build.gradle +++ b/build.gradle @@ -37,7 +37,8 @@ buildscript { ext.metrics_version = constants.getProperty("metricsVersion") ext.metrics_new_relic_version = constants.getProperty("metricsNewRelicVersion") ext.okhttp_version = '3.5.0' - ext.netty_version = '4.1.22.Final' + ext.netty_version = '4.1.29.Final' + ext.tcnative_version = '2.0.14.Final' ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion") ext.fileupload_version = '1.3.3' ext.junit_version = '4.12' @@ -264,7 +265,11 @@ allprojects { // Demand that everything uses our given version of Netty. eachDependency { details -> if (details.requested.group == 'io.netty' && details.requested.name.startsWith('netty-')) { - details.useVersion netty_version + if (details.requested.name.startsWith('netty-tcnative')){ + details.useVersion tcnative_version + } else { + details.useVersion netty_version + } } } } diff --git a/node-api/build.gradle b/node-api/build.gradle index 8bd131c237..f8b422f4a7 100644 --- a/node-api/build.gradle +++ b/node-api/build.gradle @@ -57,6 +57,7 @@ dependencies { testCompile "org.assertj:assertj-core:$assertj_version" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" testCompile project(':node-driver') + testCompile project(':test-utils') compile ("org.apache.activemq:artemis-amqp-protocol:${artemis_version}") { // Gains our proton-j version from core module. diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index c4c184342d..0b484e4de4 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -100,30 +100,32 @@ class ArtemisTcpTransport { fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration { - return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL) + return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false) } fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration { - return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL) + return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false) } - fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true): TransportConfiguration { + fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration { val options = defaultArtemisOptions(hostAndPort).toMutableMap() if (enableSSL) { options.putAll(defaultSSLOptions) (keyStore to trustStore).addToTransportOptions(options) + options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER } return TransportConfiguration(acceptorFactoryClassName, options) } - fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true): TransportConfiguration { + fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration { val options = defaultArtemisOptions(hostAndPort).toMutableMap() if (enableSSL) { options.putAll(defaultSSLOptions) (keyStore to trustStore).addToTransportOptions(options) + options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER } return TransportConfiguration(connectorFactoryClassName, options) } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt index bd19eff603..c55c40d212 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt @@ -43,8 +43,13 @@ class AMQPBridgeManager(config: MutualSslConfiguration, socksProxyConfig: SocksP private class AMQPConfigurationImpl private constructor(override val keyStore: CertificateStore, override val trustStore: CertificateStore, override val socksProxyConfig: SocksProxyConfig?, - override val maxMessageSize: Int) : AMQPConfiguration { - constructor(config: MutualSslConfiguration, socksProxyConfig: SocksProxyConfig?, maxMessageSize: Int) : this(config.keyStore.get(), config.trustStore.get(), socksProxyConfig, maxMessageSize) + override val maxMessageSize: Int, + override val useOpenSsl: Boolean) : AMQPConfiguration { + constructor(config: MutualSslConfiguration, socksProxyConfig: SocksProxyConfig?, maxMessageSize: Int) : this(config.keyStore.get(), + config.trustStore.get(), + socksProxyConfig, + maxMessageSize, + config.useOpenSsl) } private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(config, socksProxyConfig, maxMessageSize) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt index d8349bc80e..dad42d37ca 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/config/SslConfiguration.kt @@ -4,12 +4,13 @@ interface SslConfiguration { val keyStore: FileBasedCertificateStoreSupplier? val trustStore: FileBasedCertificateStoreSupplier? + val useOpenSsl: Boolean companion object { - fun mutual(keyStore: FileBasedCertificateStoreSupplier, trustStore: FileBasedCertificateStoreSupplier): MutualSslConfiguration { + fun mutual(keyStore: FileBasedCertificateStoreSupplier, trustStore: FileBasedCertificateStoreSupplier, useOpenSsl: Boolean = false ): MutualSslConfiguration { - return MutualSslOptions(keyStore, trustStore) + return MutualSslOptions(keyStore, trustStore, useOpenSsl) } } } @@ -20,4 +21,4 @@ interface MutualSslConfiguration : SslConfiguration { override val trustStore: FileBasedCertificateStoreSupplier } -private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier, override val trustStore: FileBasedCertificateStoreSupplier) : MutualSslConfiguration \ No newline at end of file +private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier, override val trustStore: FileBasedCertificateStoreSupplier, override val useOpenSsl: Boolean ) : MutualSslConfiguration \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt index d46b9a4df1..387537fd01 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt @@ -34,6 +34,7 @@ import javax.net.ssl.SSLException */ internal class AMQPChannelHandler(private val serverMode: Boolean, private val allowedRemoteLegalNames: Set?, + private var keyManagerFactory: CertHoldingKeyManagerFactoryWrapper, private val userName: String?, private val password: String?, private val trace: Boolean, @@ -45,11 +46,11 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, } private lateinit var remoteAddress: InetSocketAddress - private var localCert: X509Certificate? = null private var remoteCert: X509Certificate? = null private var eventProcessor: EventProcessor? = null private var suppressClose: Boolean = false private var badCert: Boolean = false + private var localCert: X509Certificate? = null private fun withMDC(block: () -> Unit) { val oldMDC = MDC.getCopyOfContextMap() @@ -122,7 +123,18 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, if (evt is SslHandshakeCompletionEvent) { if (evt.isSuccess) { val sslHandler = ctx.pipeline().get(SslHandler::class.java) - localCert = sslHandler.engine().session.localCertificates[0].x509 + val sslSession = sslHandler.engine().session + localCert = keyManagerFactory.getCurrentCertChain()?.get(0) + if (localCert == null) { + log.error("SSL KeyManagerFactory failed to provide a local cert") + ctx.close() + return + } + if (sslSession.peerCertificates == null || sslSession.peerCertificates.isEmpty()) { + log.error("No peer certificates") + ctx.close() + return + } remoteCert = sslHandler.engine().session.peerCertificates[0].x509 val remoteX500Name = try { CordaX500Name.build(remoteCert!!.subjectX500Principal) @@ -151,7 +163,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, } else { badCert = true } - logErrorWithMDC("Handshake failure ${evt.cause().message}") + logErrorWithMDC("Handshake failure: ${evt.cause().message}") if (log.isTraceEnabled) { withMDC { log.trace("Handshake failure", evt.cause()) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index d211918168..e1ca000f47 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -22,6 +22,7 @@ import rx.Observable import rx.subjects.PublishSubject import java.lang.Long.min import java.net.InetSocketAddress +import java.security.cert.X509Certificate import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import javax.net.ssl.KeyManagerFactory @@ -157,12 +158,18 @@ class AMQPClient(val targets: List, } } + val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory) val target = parent.currentTarget - val handler = createClientSslHelper(target, parent.allowedRemoteLegalNames, keyManagerFactory, trustManagerFactory) + val handler = if (parent.configuration.useOpenSsl){ + createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) + } else { + createClientSslHelper(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) + } pipeline.addLast("sslHandler", handler) if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) pipeline.addLast(AMQPChannelHandler(false, parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, conf.userName, conf.password, conf.trace, diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt index d39c5b4647..4f676ab2f3 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPConfiguration.kt @@ -55,5 +55,12 @@ interface AMQPConfiguration { @JvmDefault val socksProxyConfig: SocksProxyConfig? get() = null + + /** + * Whether to use the tcnative open/boring SSL provider or the default Java SSL provider + */ + @JvmDefault + val useOpenSsl: Boolean + get() = false } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt index d65b77139c..2f8cff04ef 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt @@ -23,6 +23,7 @@ import rx.Observable import rx.subjects.PublishSubject import java.net.BindException import java.net.InetSocketAddress +import java.security.cert.X509Certificate import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.locks.ReentrantLock import javax.net.ssl.KeyManagerFactory @@ -66,11 +67,17 @@ class AMQPServer(val hostName: String, override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() - val handler = createServerSslHelper(keyManagerFactory, trustManagerFactory) + val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory) + val handler = if (parent.configuration.useOpenSsl){ + createServerOpenSslHandler(wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) + } else { + createServerSslHelper(wrappedKeyManagerFactory, trustManagerFactory) + } pipeline.addLast("sslHandler", handler) if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) pipeline.addLast(AMQPChannelHandler(true, null, + wrappedKeyManagerFactory, conf.userName, conf.password, conf.trace, diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AliasProvidingKeyMangerWrapper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AliasProvidingKeyMangerWrapper.kt new file mode 100644 index 0000000000..aaa36c3f3f --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AliasProvidingKeyMangerWrapper.kt @@ -0,0 +1,94 @@ +package net.corda.nodeapi.internal.protonwrapper.netty + +import java.net.Socket +import java.security.Principal +import java.security.PrivateKey +import java.security.cert.X509Certificate +import javax.net.ssl.SSLEngine +import javax.net.ssl.X509ExtendedKeyManager +import javax.net.ssl.X509KeyManager + +interface AliasProvidingKeyMangerWrapper : X509KeyManager { + var lastAlias: String? +} + + +class AliasProvidingKeyMangerWrapperImpl(private val keyManager: X509KeyManager) : AliasProvidingKeyMangerWrapper { + override var lastAlias: String? = null + + override fun getClientAliases(p0: String?, p1: Array?): Array { + return keyManager.getClientAliases(p0, p1) + } + + override fun getServerAliases(p0: String?, p1: Array?): Array { + return getServerAliases(p0, p1) + } + + override fun chooseServerAlias(p0: String?, p1: Array?, p2: Socket?): String? { + return storeIfNotNull { keyManager.chooseServerAlias(p0, p1, p2) } + } + + override fun getCertificateChain(p0: String?): Array { + return keyManager.getCertificateChain(p0) + } + + override fun getPrivateKey(p0: String?): PrivateKey { + return keyManager.getPrivateKey(p0) + } + + override fun chooseClientAlias(p0: Array?, p1: Array?, p2: Socket?): String? { + return storeIfNotNull { keyManager.chooseClientAlias(p0, p1, p2) } + } + + private fun storeIfNotNull(func: () -> String?): String? { + val alias = func() + if (alias != null) { + lastAlias = alias + } + return alias + } +} + +class AliasProvidingExtendedKeyMangerWrapper(private val keyManager: X509ExtendedKeyManager) : X509ExtendedKeyManager(), AliasProvidingKeyMangerWrapper { + override var lastAlias: String? = null + + override fun getClientAliases(p0: String?, p1: Array?): Array { + return keyManager.getClientAliases(p0, p1) + } + + override fun getServerAliases(p0: String?, p1: Array?): Array { + return keyManager.getServerAliases(p0, p1) + } + + override fun chooseServerAlias(p0: String?, p1: Array?, p2: Socket?): String? { + return storeIfNotNull { keyManager.chooseServerAlias(p0, p1, p2) } + } + + override fun getCertificateChain(p0: String?): Array { + return keyManager.getCertificateChain(p0) + } + + override fun getPrivateKey(p0: String?): PrivateKey { + return keyManager.getPrivateKey(p0) + } + + override fun chooseClientAlias(p0: Array?, p1: Array?, p2: Socket?): String? { + return storeIfNotNull { keyManager.chooseClientAlias(p0, p1, p2) } + } + + override fun chooseEngineClientAlias(p0: Array?, p1: Array?, p2: SSLEngine?): String? { + return storeIfNotNull { keyManager.chooseEngineClientAlias(p0, p1, p2) } + } + + override fun chooseEngineServerAlias(p0: String?, p1: Array?, p2: SSLEngine?): String? { + return storeIfNotNull { keyManager.chooseEngineServerAlias(p0, p1, p2) } + } + + private fun storeIfNotNull(func: () -> String?): String? { + val alias = func() + if (alias != null) { + lastAlias = alias + } + return alias + } +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CertHoldingKeyManagerFactoryWrapper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CertHoldingKeyManagerFactoryWrapper.kt new file mode 100644 index 0000000000..102886fcc2 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/CertHoldingKeyManagerFactoryWrapper.kt @@ -0,0 +1,66 @@ +package net.corda.nodeapi.internal.protonwrapper.netty + +import java.security.KeyStore +import java.security.cert.X509Certificate +import javax.net.ssl.* + + +class CertHoldingKeyManagerFactorySpiWrapper(private val factorySpi: KeyManagerFactorySpi) : KeyManagerFactorySpi() { + override fun engineInit(p0: KeyStore?, p1: CharArray?) { + val engineInitMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineInit", KeyStore::class.java, CharArray::class.java) + engineInitMethod.isAccessible = true + engineInitMethod.invoke(factorySpi, p0, p1) + } + + override fun engineInit(p0: ManagerFactoryParameters?) { + val engineInitMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineInit", ManagerFactoryParameters::class.java) + engineInitMethod.isAccessible = true + engineInitMethod.invoke(factorySpi, p0) + } + + private fun getKeyManagersImpl(): Array { + val engineGetKeyManagersMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineGetKeyManagers") + engineGetKeyManagersMethod.isAccessible = true + @Suppress("UNCHECKED_CAST") + val keyManagers = engineGetKeyManagersMethod.invoke(factorySpi) as Array + return if (factorySpi is CertHoldingKeyManagerFactorySpiWrapper) keyManagers else keyManagers.mapNotNull { + @Suppress("USELESS_CAST") // the casts to KeyManager are not useless - without them, the typed array will be of type Any + when (it) { + is X509ExtendedKeyManager -> AliasProvidingExtendedKeyMangerWrapper(it) as KeyManager + is X509KeyManager -> AliasProvidingKeyMangerWrapperImpl(it) as KeyManager + else -> null + } + }.toTypedArray() + } + + private val keyManagers = lazy { getKeyManagersImpl() } + + override fun engineGetKeyManagers(): Array { + return keyManagers.value + } +} + +/** + * You can wrap a key manager factory in this class if you need to get the cert chain currently used to identify or + * verify. When using for TLS channels, make sure to wrap the (singleton) factory separately on each channel, as + * the wrapper is not thread safe as in it will return the last used alias/cert chain and has itself no notion + * of belonging to a certain channel. + */ +class CertHoldingKeyManagerFactoryWrapper(factory: KeyManagerFactory) : KeyManagerFactory(getFactorySpi(factory), factory.provider, factory.algorithm) { + companion object { + private fun getFactorySpi(factory: KeyManagerFactory): KeyManagerFactorySpi { + val spiField = KeyManagerFactory::class.java.getDeclaredField("factorySpi") + spiField.isAccessible = true + return CertHoldingKeyManagerFactorySpiWrapper(spiField.get(factory) as KeyManagerFactorySpi) + } + } + + fun getCurrentCertChain(): Array? { + val keyManager = keyManagers.firstOrNull() + val alias = if (keyManager is AliasProvidingKeyMangerWrapper) keyManager.lastAlias else null + return if (alias != null && keyManager is X509KeyManager) { + keyManager.getCertificateChain(alias) + } else null + } + +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt index 9850e6a8ce..932a28f6c6 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt @@ -1,6 +1,9 @@ package net.corda.nodeapi.internal.protonwrapper.netty +import io.netty.buffer.ByteBufAllocator +import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslHandler +import io.netty.handler.ssl.SslProvider import net.corda.core.crypto.SecureHash import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name @@ -126,6 +129,23 @@ internal fun createClientSslHelper(target: NetworkHostAndPort, return SslHandler(sslEngine) } +internal fun createClientOpenSslHandler(target: NetworkHostAndPort, + expectedRemoteLegalNames: Set, + keyManagerFactory: KeyManagerFactory, + trustManagerFactory: TrustManagerFactory, + alloc: ByteBufAllocator): SslHandler { + val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() + val sslEngine = sslContext.newEngine(alloc, target.host, target.port) + sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() + sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray() + if (expectedRemoteLegalNames.size == 1) { + val sslParameters = sslEngine.sslParameters + sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) + sslEngine.sslParameters = sslParameters + } + return SslHandler(sslEngine) +} + internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslHandler { val sslContext = SSLContext.getInstance("TLS") @@ -159,6 +179,18 @@ internal fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateSto return CertPathTrustManagerParameters(pkixParams) } +internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, + trustManagerFactory: TrustManagerFactory, + alloc: ByteBufAllocator): SslHandler { + val sslContext = SslContextBuilder.forServer(keyManagerFactory).sslProvider(SslProvider.OPENSSL).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() + val sslEngine = sslContext.newEngine(alloc) + sslEngine.useClientMode = false + sslEngine.needClientAuth = true + sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() + sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray() + return SslHandler(sslEngine) +} + fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.password.toCharArray()) fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TrustManagerFactoryWrapper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TrustManagerFactoryWrapper.kt new file mode 100644 index 0000000000..dd3a318cae --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TrustManagerFactoryWrapper.kt @@ -0,0 +1,36 @@ +package net.corda.nodeapi.internal.protonwrapper.netty + +import java.security.KeyStore +import javax.net.ssl.* + +class LoggingTrustManagerFactorySpiWrapper(private val factorySpi: TrustManagerFactorySpi) : TrustManagerFactorySpi() { + override fun engineGetTrustManagers(): Array { + val engineGetTrustManagersMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineGetTrustManagers") + engineGetTrustManagersMethod.isAccessible = true + @Suppress("UNCHECKED_CAST") + val trustManagers = engineGetTrustManagersMethod.invoke(factorySpi) as Array + return if (factorySpi is LoggingTrustManagerFactorySpiWrapper) trustManagers else trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray() + } + + override fun engineInit(p0: KeyStore?) { + val engineInitMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineInit", KeyStore::class.java) + engineInitMethod.isAccessible = true + engineInitMethod.invoke(factorySpi, p0) + } + + override fun engineInit(p0: ManagerFactoryParameters?) { + val engineInitMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineInit", ManagerFactoryParameters::class.java) + engineInitMethod.isAccessible = true + engineInitMethod.invoke(factorySpi, p0) + } +} + +class LoggingTrustManagerFactoryWrapper(factory: TrustManagerFactory) : TrustManagerFactory(getFactorySpi(factory), factory.provider, factory.algorithm) { + companion object { + private fun getFactorySpi(factory: TrustManagerFactory): TrustManagerFactorySpi { + val spiField = TrustManagerFactory::class.java.getDeclaredField("factorySpi") + spiField.isAccessible = true + return LoggingTrustManagerFactorySpiWrapper(spiField.get(factory) as TrustManagerFactorySpi) + } + } +} \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt index 3db9072041..0a3d770d33 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/X509UtilitiesTest.kt @@ -1,5 +1,10 @@ package net.corda.nodeapi.internal.crypto + +import io.netty.handler.ssl.ClientAuth +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.SslProvider +import net.corda.core.crypto.Crypto import net.corda.core.crypto.* import net.corda.core.crypto.Crypto.COMPOSITE_KEY import net.corda.core.crypto.Crypto.ECDSA_SECP256K1_SHA256 @@ -28,8 +33,12 @@ import net.corda.serialization.internal.amqp.amqpMagic import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.TestIdentity -import net.corda.testing.internal.stubs.CertificateStoreStubs +import net.corda.testing.driver.PortAllocation +import net.corda.testing.internal.NettyTestClient +import net.corda.testing.internal.NettyTestHandler +import net.corda.testing.internal.NettyTestServer import net.corda.testing.internal.createDevIntermediateCaCertPath +import net.corda.testing.internal.stubs.CertificateStoreStubs import net.i2p.crypto.eddsa.EdDSAPrivateKey import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x509.* @@ -65,6 +74,9 @@ class X509UtilitiesTest { "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" ) + + val portAllocation = PortAllocation.Incremental(10000) + // We ensure that all of the algorithms are both used (at least once) as first and second in the following [Pair]s. // We also add [DEFAULT_TLS_SIGNATURE_SCHEME] and [DEFAULT_IDENTITY_SIGNATURE_SCHEME] combinations for consistency. val certChainSchemeCombinations = listOf( @@ -348,6 +360,65 @@ class X509UtilitiesTest { assertTrue(done) } + @Test + fun `create server cert and use in OpenSSL channel`() { + val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(tempFolder.root.toPath(), keyStorePassword = "serverstorepass") + + val (rootCa, intermediateCa) = createDevIntermediateCaCertPath() + + // Generate server cert and private key and populate another keystore suitable for SSL + sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) + sslConfig.createTrustStore(rootCa.certificate) + + val keyStore = sslConfig.keyStore.get() + val trustStore = sslConfig.trustStore.get() + + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + keyManagerFactory.init(keyStore) + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(trustStore) + + + val sslServerContext = SslContextBuilder + .forServer(keyManagerFactory) + .trustManager(trustManagerFactory) + .clientAuth(ClientAuth.REQUIRE) + .ciphers(CIPHER_SUITES.toMutableList()) + .sslProvider(SslProvider.OPENSSL) + .protocols("TLSv1.2") + .build() + val sslClientContext = SslContextBuilder + .forClient() + .keyManager(keyManagerFactory) + .trustManager(trustManagerFactory) + .ciphers(CIPHER_SUITES.toMutableList()) + .sslProvider(SslProvider.OPENSSL) + .protocols("TLSv1.2") + .build() + val serverHandler = NettyTestHandler { ctx, msg -> ctx?.writeAndFlush(msg) } + val clientHandler = NettyTestHandler { _, msg -> assertEquals("Hello", NettyTestHandler.readString(msg)) } + NettyTestServer(sslServerContext, serverHandler, portAllocation.nextPort()).use { server -> + server.start() + NettyTestClient(sslClientContext, InetAddress.getLocalHost().canonicalHostName, server.port, clientHandler).use { client -> + client.start() + + clientHandler.writeString("Hello") + val readCalled = clientHandler.waitForReadCalled() + clientHandler.rethrowIfFailed() + serverHandler.rethrowIfFailed() + assertTrue(readCalled) + assertEquals(1, serverHandler.readCalledCounter) + assertEquals(1, clientHandler.readCalledCounter) + + val peerChain = client.engine!!.session.peerCertificates.x509 + val peerX500Principal = peerChain[0].subjectX500Principal + assertEquals(MEGA_CORP.name.x500Principal, peerX500Principal) + X509Utilities.validateCertificateChain(rootCa.certificate, peerChain) + } + } + } + private fun tempFile(name: String): Path = tempFolder.root.toPath() / name private fun MutualSslConfiguration.createTrustStore(rootCert: X509Certificate) { diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestKeyManagerFactoryWrapper.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestKeyManagerFactoryWrapper.kt new file mode 100644 index 0000000000..8ca20dad62 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestKeyManagerFactoryWrapper.kt @@ -0,0 +1,95 @@ +package net.corda.nodeapi.internal.protonwrapper.netty + +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.internal.div +import net.corda.node.services.config.NodeConfiguration +import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.internal.rigorousMock +import net.corda.testing.internal.stubs.CertificateStoreStubs +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.X509KeyManager +import kotlin.test.* + +class TestKeyManagerFactoryWrapper { + + @Rule + @JvmField + val temporaryFolder = TemporaryFolder() + + private abstract class AbstractNodeConfiguration : NodeConfiguration + + + @Test + fun testWrapping() { + val baseDir = temporaryFolder.root.toPath() / "testWrapping" + val certDir = baseDir / "certificates" + val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(temporaryFolder.root.toPath(), keyStorePassword = "serverstorepass") + val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(temporaryFolder.root.toPath()) + + val config = rigorousMock().also { + doReturn(baseDir).whenever(it).baseDirectory + doReturn(certDir).whenever(it).certificatesDirectory + doReturn(ALICE_NAME).whenever(it).myLegalName + doReturn(sslConfig).whenever(it).p2pSslOptions + doReturn(signingCertificateStore).whenever(it).signingCertificateStore + } + config.configureWithDevSSLCertificate() + + val underlyingKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + + val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(underlyingKeyManagerFactory) + wrappedKeyManagerFactory.init(config.p2pSslOptions.keyStore.get()) + val keyManagers = wrappedKeyManagerFactory.keyManagers + assertFalse(keyManagers.isEmpty()) + assertNull(wrappedKeyManagerFactory.getCurrentCertChain()) + val keyManager = keyManagers.first() as X509KeyManager + val alias = keyManager.chooseClientAlias(arrayOf("EC_EC"), null, null) + assertNotNull(alias) + val certChain = wrappedKeyManagerFactory.getCurrentCertChain() + assertNotNull(certChain) + assertTrue(certChain!!.isNotEmpty()) + + assertEquals(alias, (keyManager as AliasProvidingKeyMangerWrapper).lastAlias) + } + + @Test + fun testWrappingSeparately() { + val baseDir = temporaryFolder.root.toPath() / "testWrapping" + val certDir = baseDir / "certificates" + val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(temporaryFolder.root.toPath(), keyStorePassword = "serverstorepass") + val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(temporaryFolder.root.toPath()) + + val config = rigorousMock().also { + doReturn(baseDir).whenever(it).baseDirectory + doReturn(certDir).whenever(it).certificatesDirectory + doReturn(ALICE_NAME).whenever(it).myLegalName + doReturn(sslConfig).whenever(it).p2pSslOptions + doReturn(signingCertificateStore).whenever(it).signingCertificateStore + } + config.configureWithDevSSLCertificate() + + val underlyingKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + + val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(underlyingKeyManagerFactory) + wrappedKeyManagerFactory.init(config.p2pSslOptions.keyStore.get()) + + val otherWrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(underlyingKeyManagerFactory) + + val keyManagers = wrappedKeyManagerFactory.keyManagers + assertFalse(keyManagers.isEmpty()) + assertNull(wrappedKeyManagerFactory.getCurrentCertChain()) + val keyManager = keyManagers.first() as X509KeyManager + keyManager.chooseClientAlias(arrayOf("EC_EC"), null, null) + val certChain = wrappedKeyManagerFactory.getCurrentCertChain() + assertNotNull(certChain) + assertTrue(certChain!!.isNotEmpty()) + + assertNull(otherWrappedKeyManagerFactory.getCurrentCertChain()) + } + +} \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestTrustManagerFactoryWrapper.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestTrustManagerFactoryWrapper.kt new file mode 100644 index 0000000000..4da8b78796 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/TestTrustManagerFactoryWrapper.kt @@ -0,0 +1,51 @@ +package net.corda.nodeapi.internal.protonwrapper.netty + +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.internal.div +import net.corda.node.services.config.NodeConfiguration +import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.internal.rigorousMock +import net.corda.testing.internal.stubs.CertificateStoreStubs +import org.junit.Assert.assertTrue +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import javax.net.ssl.TrustManagerFactory + +class TestTrustManagerFactoryWrapper { + + @Rule + @JvmField + val temporaryFolder = TemporaryFolder() + + private abstract class AbstractNodeConfiguration : NodeConfiguration + + + @Test + fun testWrapping() { + val baseDir = temporaryFolder.root.toPath() / "testWrapping" + val certDir = baseDir / "certificates" + val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(temporaryFolder.root.toPath(), keyStorePassword = "serverstorepass") + val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(temporaryFolder.root.toPath()) + + val config = rigorousMock().also { + doReturn(baseDir).whenever(it).baseDirectory + doReturn(certDir).whenever(it).certificatesDirectory + doReturn(ALICE_NAME).whenever(it).myLegalName + doReturn(sslConfig).whenever(it).p2pSslOptions + doReturn(signingCertificateStore).whenever(it).signingCertificateStore + } + config.configureWithDevSSLCertificate() + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + + val wrapped = LoggingTrustManagerFactoryWrapper(trustManagerFactory) + wrapped.init(initialiseTrustStoreAndEnableCrlChecking(config.p2pSslOptions.trustStore.get(), false)) + + val trustManagers = wrapped.trustManagers + assertTrue(trustManagers.size > 0) + assertTrue(trustManagers[0] is LoggingTrustManagerWrapper) + } +} \ No newline at end of file diff --git a/node/build.gradle b/node/build.gradle index 6fb0a3945d..c28cd1ef8f 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -223,8 +223,12 @@ dependencies { testCompile("io.netty:netty-example:$netty_version") { exclude group: "io.netty", module: "netty-tcnative" exclude group: "ch.qos.logback", module: "logback-classic" + } + // Adding native SSL library to allow using native SSL with Artemis and AMQP + compile "io.netty:netty-tcnative-boringssl-static:$tcnative_version" + testCompile(project(':test-cli')) } diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt index 8252136a8d..4499bebfe6 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt @@ -5,11 +5,9 @@ import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.toStringShort import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.hours import net.corda.core.utilities.loggerFor -import net.corda.node.services.config.EnterpriseConfiguration -import net.corda.node.services.config.MutualExclusionConfiguration -import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.config.* import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent @@ -23,8 +21,8 @@ import net.corda.testing.core.BOB_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.core.TestIdentity import net.corda.testing.driver.PortAllocation -import net.corda.testing.internal.stubs.CertificateStoreStubs import net.corda.testing.internal.rigorousMock +import net.corda.testing.internal.stubs.CertificateStoreStubs import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString @@ -34,12 +32,22 @@ import org.junit.Ignore import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder +import org.junit.runner.RunWith +import org.junit.runners.Parameterized import java.util.* +import kotlin.concurrent.thread import kotlin.system.measureNanoTime import kotlin.system.measureTimeMillis import kotlin.test.assertEquals -class AMQPBridgeTest { +@RunWith(Parameterized::class) +class AMQPBridgeTest(private val useOpenSsl: Boolean) { + companion object { + @JvmStatic + @Parameterized.Parameters(name = "useOpenSsl = {0}") + fun data(): Collection = listOf(false, true) + } + @Rule @JvmField val temporaryFolder = TemporaryFolder() @@ -198,13 +206,30 @@ class AMQPBridgeTest { var timeNanosCreateMessage = 0L var timeNanosSendMessage = 0L var timeMillisRead = 0L + + val recThread = thread { + val current = artemisConsumer.receive() + val messageId = current.getIntProperty(P2PMessagingHeaders.senderUUID) + assertEquals(numReceived, messageId) + ++numReceived + current.acknowledge() + timeMillisRead = measureTimeMillis { + while (numReceived < numMessages) { + val currentMsg = artemisConsumer.receive() + val loopMessageId = currentMsg.getIntProperty(P2PMessagingHeaders.senderUUID) + assertEquals(numReceived, loopMessageId) + ++numReceived + currentMsg.acknowledge() + } + } + } val simpleSourceQueueName = SimpleString(sourceQueueName) val totalTimeMillis = measureTimeMillis { - repeat(numMessages) { + repeat(numMessages) { i -> var artemisMessage: ClientMessage? = null timeNanosCreateMessage += measureNanoTime { artemisMessage = artemis.session.createMessage(true).apply { - putIntProperty("CountProp", it) + putIntProperty(P2PMessagingHeaders.senderUUID, i) writeBodyBufferBytes(rubbishPayload) // Use the magic deduplication property built into Artemis as our message identity too putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) @@ -215,18 +240,9 @@ class AMQPBridgeTest { } } artemisClient.started!!.session.commit() - - - timeMillisRead = measureTimeMillis { - while (numReceived < numMessages) { - val current = artemisConsumer.receive() - val messageId = current.getIntProperty("CountProp") - assertEquals(numReceived, messageId) - ++numReceived - current.acknowledge() - } - } + recThread.join(1.hours.toMillis()) } + println("Creating $numMessages messages took ${timeNanosCreateMessage / (1000 * 1000)} milliseconds") println("Sending $numMessages messages took ${timeNanosSendMessage / (1000 * 1000)} milliseconds") println("Receiving $numMessages messages took $timeMillisRead milliseconds") @@ -244,7 +260,7 @@ class AMQPBridgeTest { private fun createArtemis(sourceQueueName: String?): Triple { val baseDir = temporaryFolder.root.toPath() / "artemis" val certificatesDirectory = baseDir / "certificates" - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = useOpenSsl) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val artemisConfig = rigorousMock().also { doReturn(baseDir).whenever(it).baseDirectory @@ -260,7 +276,8 @@ class AMQPBridgeTest { artemisConfig.configureWithDevSSLCertificate() val artemisServer = ArtemisMessagingServer(artemisConfig, artemisAddress.copy(host = "0.0.0.0"), MAX_MESSAGE_SIZE) - val artemisClient = ArtemisMessagingClient(artemisConfig.p2pSslOptions, artemisAddress, MAX_MESSAGE_SIZE) + + val artemisClient = ArtemisMessagingClient(artemisConfig.p2pSslOptions, artemisAddress, MAX_MESSAGE_SIZE, confirmationWindowSize = artemisConfig.enterpriseConfiguration.tuning.p2pConfirmationWindowSize) artemisServer.start() artemisClient.start() @@ -279,7 +296,7 @@ class AMQPBridgeTest { private fun createArtemisReceiver(targetAdress: NetworkHostAndPort, workingDir: String): Pair { val baseDir = temporaryFolder.root.toPath() / workingDir val certificatesDirectory = baseDir / "certificates" - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = useOpenSsl) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val artemisConfig = rigorousMock().also { doReturn(baseDir).whenever(it).baseDirectory @@ -288,7 +305,9 @@ class AMQPBridgeTest { doReturn(signingCertificateStore).whenever(it).signingCertificateStore doReturn(p2pSslConfiguration).whenever(it).p2pSslOptions doReturn(targetAdress).whenever(it).p2pAddress - doReturn("").whenever(it).jmxMonitoringHttpPort + doReturn(null).whenever(it).jmxMonitoringHttpPort + @Suppress("DEPRECATION") + doReturn(emptyList()).whenever(it).certificateChainCheckPolicies doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration } artemisConfig.configureWithDevSSLCertificate() @@ -304,7 +323,7 @@ class AMQPBridgeTest { private fun createAMQPServer(maxMessageSize: Int = MAX_MESSAGE_SIZE): AMQPServer { val baseDir = temporaryFolder.root.toPath() / "server" val certificatesDirectory = baseDir / "certificates" - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = useOpenSsl) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val serverConfig = rigorousMock().also { doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory @@ -321,6 +340,7 @@ class AMQPBridgeTest { override val trustStore = serverConfig.p2pSslOptions.trustStore.get() override val trace: Boolean = true override val maxMessageSize: Int = maxMessageSize + override val useOpenSsl = serverConfig.p2pSslOptions.useOpenSsl } return AMQPServer("0.0.0.0", amqpAddress.port, diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 42e035b52f..5007928e50 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -40,13 +40,32 @@ import org.junit.Assert.assertArrayEquals import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder +import org.junit.runner.RunWith +import org.junit.runners.Parameterized import java.security.cert.X509Certificate import javax.net.ssl.* import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertTrue -class ProtonWrapperTests { + +@RunWith(Parameterized::class) +class ProtonWrapperTests(val sslSetup: SslSetup) { + companion object { + data class SslSetup(val clientNative: Boolean, val serverNative: Boolean) { + override fun toString(): String = "Client: ${if (clientNative) "openSsl" else "javaSsl"} Server: ${if (serverNative) "openSsl" else "javaSsl"} " + } + + @JvmStatic + @Parameterized.Parameters(name = "{0}") + fun data(): Collection = listOf( + SslSetup(false, false), + SslSetup(true, false), + SslSetup(false, true), + SslSetup(true, true) + ) + } + @Rule @JvmField val temporaryFolder = TemporaryFolder() @@ -407,7 +426,7 @@ class ProtonWrapperTests { val baseDirectory = temporaryFolder.root.toPath() / "artemis" val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.serverNative) val artemisConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory @@ -432,7 +451,7 @@ class ProtonWrapperTests { val baseDirectory = temporaryFolder.root.toPath() / "client" val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.clientNative) val clientConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory @@ -450,6 +469,7 @@ class ProtonWrapperTests { override val trustStore = clientTruststore override val trace: Boolean = true override val maxMessageSize: Int = maxMessageSize + override val useOpenSsl: Boolean = sslSetup.clientNative } return AMQPClient( listOf(NetworkHostAndPort("localhost", serverPort), @@ -463,7 +483,7 @@ class ProtonWrapperTests { val baseDirectory = temporaryFolder.root.toPath() / "client_%$id" val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.clientNative) val clientConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory @@ -481,6 +501,7 @@ class ProtonWrapperTests { override val trustStore = clientTruststore override val trace: Boolean = true override val maxMessageSize: Int = maxMessageSize + override val useOpenSsl: Boolean = sslSetup.clientNative } return AMQPClient( listOf(NetworkHostAndPort("localhost", serverPort)), @@ -496,7 +517,7 @@ class ProtonWrapperTests { val baseDirectory = temporaryFolder.root.toPath() / "server" val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) - val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) + val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.serverNative) val serverConfig = rigorousMock().also { doReturn(baseDirectory).whenever(it).baseDirectory doReturn(certificatesDirectory).whenever(it).certificatesDirectory @@ -514,6 +535,7 @@ class ProtonWrapperTests { override val trustStore = serverTruststore override val trace: Boolean = true override val maxMessageSize: Int = maxMessageSize + override val useOpenSsl: Boolean = sslSetup.serverNative } return AMQPServer( "0.0.0.0", diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index 41824395d7..3fe0ee2875 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -280,7 +280,8 @@ data class NodeConfigurationImpl( override val flowMonitorPeriodMillis: Duration = DEFAULT_FLOW_MONITOR_PERIOD_MILLIS, override val flowMonitorSuspensionLoggingThresholdMillis: Duration = DEFAULT_FLOW_MONITOR_SUSPENSION_LOGGING_THRESHOLD_MILLIS, override val cordappDirectories: List = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT), - override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA + override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA, + private val useOpenSsl: Boolean = false ) : NodeConfiguration { companion object { private val logger = loggerFor() @@ -313,7 +314,7 @@ data class NodeConfigurationImpl( private val p2pKeyStore = FileBasedCertificateStoreSupplier(p2pKeystorePath, keyStorePassword) private val p2pTrustStoreFilePath: Path get() = certificatesDirectory / "truststore.jks" private val p2pTrustStore = FileBasedCertificateStoreSupplier(p2pTrustStoreFilePath, trustStorePassword) - override val p2pSslOptions: MutualSslConfiguration = SslConfiguration.mutual(p2pKeyStore, p2pTrustStore) + override val p2pSslOptions: MutualSslConfiguration = SslConfiguration.mutual(p2pKeyStore, p2pTrustStore, useOpenSsl) override val rpcOptions: NodeRpcOptions get() { diff --git a/node/src/main/resources/reference.conf b/node/src/main/resources/reference.conf index 14322b602c..2f3da68fb2 100644 --- a/node/src/main/resources/reference.conf +++ b/node/src/main/resources/reference.conf @@ -1,6 +1,7 @@ emailAddress = "admin@company.com" keyStorePassword = "cordacadevpass" trustStorePassword = "trustpass" +useOpenSsl = false crlCheckSoftFail = true lazyBridgeStart = true additionalP2PAddresses = [] diff --git a/node/src/test/kotlin/net/corda/node/utilities/NettyEngineBasedTlsAuthenticationTests.kt b/node/src/test/kotlin/net/corda/node/utilities/NettyEngineBasedTlsAuthenticationTests.kt new file mode 100644 index 0000000000..d9929f82c1 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/utilities/NettyEngineBasedTlsAuthenticationTests.kt @@ -0,0 +1,317 @@ +package net.corda.node.utilities + +import io.netty.handler.ssl.ClientAuth +import io.netty.handler.ssl.SslContext +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.SslProvider +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SignatureScheme +import net.corda.core.identity.CordaX500Name +import net.corda.core.internal.div +import net.corda.nodeapi.internal.crypto.* +import net.corda.testing.driver.PortAllocation +import net.corda.testing.internal.NettyTestClient +import net.corda.testing.internal.NettyTestHandler +import net.corda.testing.internal.NettyTestServer +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import java.net.InetAddress +import java.nio.file.Path +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.TrustManagerFactory +import javax.security.auth.x500.X500Principal +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +@RunWith(Parameterized::class) +class NettyEngineBasedTlsAuthenticationTests(val sslSetup: SslSetup) { + + @Rule + @JvmField + val tempFolder: TemporaryFolder = TemporaryFolder() + + // Root CA. + private val ROOT_X500 = X500Principal("CN=Root_CA_1,O=R3CEV,L=London,C=GB") + // Intermediate CA. + private val INTERMEDIATE_X500 = X500Principal("CN=Intermediate_CA_1,O=R3CEV,L=London,C=GB") + // TLS server (server). + private val CLIENT_1_X500 = CordaX500Name(commonName = "Client_1", organisation = "R3CEV", locality = "London", country = "GB") + // TLS client (client). + private val CLIENT_2_X500 = CordaX500Name(commonName = "Client_2", organisation = "R3CEV", locality = "London", country = "GB") + // Password for keys and keystores. + private val PASSWORD = "dummypassword" + // Default supported TLS schemes for Corda nodes. + private val CORDA_TLS_CIPHER_SUITES = arrayOf( + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" + ) + + private fun tempFile(name: String): Path = tempFolder.root.toPath() / name + + companion object { + private val portAllocation = PortAllocation.Incremental(10000) + + data class SslSetup(val clientNative: Boolean, val serverNative: Boolean) { + override fun toString(): String = "Client: ${if (clientNative) "openSsl" else "javaSsl"} Server: ${if (serverNative) "openSsl" else "javaSsl"} " + } + + @JvmStatic + @Parameterized.Parameters(name = "{0}") + fun data(): Collection = listOf( + SslSetup(false, false), + SslSetup(true, false), + SslSetup(false, true), + SslSetup(true, true) + ) + } + + @Test + fun `All EC R1`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256 + ) + + testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") + } + + @Test + fun `All RSA`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.RSA_SHA256, + intermediateCAScheme = Crypto.RSA_SHA256, + serverCAScheme = Crypto.RSA_SHA256, + serverTLSScheme = Crypto.RSA_SHA256, + clientCAScheme = Crypto.RSA_SHA256, + clientTLSScheme = Crypto.RSA_SHA256 + ) + + testConnect(serverContext, clientContext, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") + } + + // Server's public key type is the one selected if users use different key types (e.g RSA and EC R1). + @Test + fun `Server RSA - Client EC R1 - CAs all EC R1`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverTLSScheme = Crypto.RSA_SHA256, + clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256 + ) + + testConnect(serverContext, clientContext, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") // Server's key type is selected. + } + + @Test + fun `Server EC R1 - Client RSA - CAs all EC R1`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientTLSScheme = Crypto.RSA_SHA256 + ) + + testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") // Server's key type is selected. + } + + @Test + fun `Server EC R1 - Client EC R1 - CAs all RSA`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.RSA_SHA256, + intermediateCAScheme = Crypto.RSA_SHA256, + serverCAScheme = Crypto.RSA_SHA256, + serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientCAScheme = Crypto.RSA_SHA256, + clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256 + ) + + testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") + } + + @Test + fun `Server EC R1 - Client RSA - Mixed CAs`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + intermediateCAScheme = Crypto.RSA_SHA256, + serverCAScheme = Crypto.RSA_SHA256, + serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientTLSScheme = Crypto.RSA_SHA256 + ) + + testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") + } + + // According to RFC 5246 (TLS 1.2), section 7.4.1.2 ClientHello cipher_suites: + // This is a list of the cryptographic options supported by the client, with the client's first preference first. + // + // However, the server is still free to ignore this order and pick what it thinks is best, + // see https://security.stackexchange.com/questions/121608 for more information. + @Test + fun `TLS cipher suite order matters - implementation dependent`() { + val (serverContext, clientContext) = buildContexts( + rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256, + clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256, + cipherSuitesServer = arrayOf("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256"), // GCM then CBC. + cipherSuitesClient = arrayOf("TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") // CBC then GCM. + + ) + + val expectedCipherSuite = if (sslSetup.clientNative || sslSetup.serverNative) + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" // server wins if boring ssl is involved + else + "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256" // client wins in pure JRE SSL + testConnect(serverContext, clientContext, expectedCipherSuite) + } + + private fun buildContexts( + rootCAScheme: SignatureScheme, + intermediateCAScheme: SignatureScheme, + serverCAScheme: SignatureScheme, + serverTLSScheme: SignatureScheme, + clientCAScheme: SignatureScheme, + clientTLSScheme: SignatureScheme, + cipherSuitesServer: Array = CORDA_TLS_CIPHER_SUITES, + cipherSuitesClient: Array = CORDA_TLS_CIPHER_SUITES + ): Pair { + + val trustStorePath = tempFile("cordaTrustStore.jks") + val serverTLSKeyStorePath = tempFile("serversslkeystore.jks") + val clientTLSKeyStorePath = tempFile("clientsslkeystore.jks") + + // ROOT CA key and cert. + val rootCAKeyPair = Crypto.generateKeyPair(rootCAScheme) + val rootCACert = X509Utilities.createSelfSignedCACertificate(ROOT_X500, rootCAKeyPair) + + // Intermediate CA key and cert. + val intermediateCAKeyPair = Crypto.generateKeyPair(intermediateCAScheme) + val intermediateCACert = X509Utilities.createCertificate( + CertificateType.INTERMEDIATE_CA, + rootCACert, + rootCAKeyPair, + INTERMEDIATE_X500, + intermediateCAKeyPair.public + ) + + // Client 1 keys, certs and SSLKeyStore. + val serverCAKeyPair = Crypto.generateKeyPair(serverCAScheme) + val serverCACert = X509Utilities.createCertificate( + CertificateType.NODE_CA, + intermediateCACert, + intermediateCAKeyPair, + CLIENT_1_X500.x500Principal, + serverCAKeyPair.public + ) + + val serverTLSKeyPair = Crypto.generateKeyPair(serverTLSScheme) + val serverTLSCert = X509Utilities.createCertificate( + CertificateType.TLS, + serverCACert, + serverCAKeyPair, + CLIENT_1_X500.x500Principal, + serverTLSKeyPair.public + ) + + val serverTLSKeyStore = loadOrCreateKeyStore(serverTLSKeyStorePath, PASSWORD) + serverTLSKeyStore.addOrReplaceKey( + X509Utilities.CORDA_CLIENT_TLS, + serverTLSKeyPair.private, + PASSWORD.toCharArray(), + arrayOf(serverTLSCert, serverCACert, intermediateCACert, rootCACert)) + // serverTLSKeyStore.save(serverTLSKeyStorePath, PASSWORD) + val serverTLSKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + serverTLSKeyManagerFactory.init(serverTLSKeyStore, PASSWORD.toCharArray()) + + // Client 2 keys, certs and SSLKeyStore. + val clientCAKeyPair = Crypto.generateKeyPair(clientCAScheme) + val clientCACert = X509Utilities.createCertificate( + CertificateType.NODE_CA, + intermediateCACert, + intermediateCAKeyPair, + CLIENT_2_X500.x500Principal, + clientCAKeyPair.public + ) + + val clientTLSKeyPair = Crypto.generateKeyPair(clientTLSScheme) + val clientTLSCert = X509Utilities.createCertificate( + CertificateType.TLS, + clientCACert, + clientCAKeyPair, + CLIENT_2_X500.x500Principal, + clientTLSKeyPair.public + ) + + val clientTLSKeyStore = loadOrCreateKeyStore(clientTLSKeyStorePath, PASSWORD) + clientTLSKeyStore.addOrReplaceKey( + X509Utilities.CORDA_CLIENT_TLS, + clientTLSKeyPair.private, + PASSWORD.toCharArray(), + arrayOf(clientTLSCert, clientCACert, intermediateCACert, rootCACert)) + // clientTLSKeyStore.save(clientTLSKeyStorePath, PASSWORD) + val clientTLSKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + clientTLSKeyManagerFactory.init(clientTLSKeyStore, PASSWORD.toCharArray()) + + val trustStore = loadOrCreateKeyStore(trustStorePath, PASSWORD) + trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCACert) + trustStore.addOrReplaceCertificate(X509Utilities.CORDA_INTERMEDIATE_CA, intermediateCACert) + // trustStore.save(trustStorePath, PASSWORD) + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(trustStore) + + return Pair( + SslContextBuilder + .forServer(serverTLSKeyManagerFactory) + .trustManager(trustManagerFactory) + .ciphers(cipherSuitesServer.toMutableList()) + .clientAuth(ClientAuth.REQUIRE) + .protocols("TLSv1.2") + .sslProvider(if (sslSetup.serverNative) SslProvider.OPENSSL else SslProvider.JDK) + .build(), + SslContextBuilder + .forClient() + .keyManager(clientTLSKeyManagerFactory) + .trustManager(trustManagerFactory) + .ciphers(cipherSuitesClient.toMutableList()) + .protocols("TLSv1.2") + .sslProvider(if (sslSetup.clientNative) SslProvider.OPENSSL else SslProvider.JDK) + .build() + ) + } + + private fun testConnect(serverContext: SslContext, clientContext: SslContext, expectedCipherSuite: String) { + val serverHandler = NettyTestHandler { ctx, msg -> ctx?.writeAndFlush(msg) } + val clientHandler = NettyTestHandler { _, msg -> assertEquals("Hello!", NettyTestHandler.readString(msg)) } + + NettyTestServer(serverContext, serverHandler, portAllocation.nextPort()).use { server -> + server.start() + NettyTestClient(clientContext, InetAddress.getLocalHost().canonicalHostName, server.port, clientHandler).use { client -> + client.start() + + clientHandler.writeString("Hello!") + val readCalled = clientHandler.waitForReadCalled() + clientHandler.rethrowIfFailed() + serverHandler.rethrowIfFailed() + assertEquals(1, serverHandler.readCalledCounter) + assertEquals(1, clientHandler.readCalledCounter) + assertTrue(readCalled) + + assertEquals(expectedCipherSuite, client.engine!!.session.cipherSuite) + } + } + } +} \ No newline at end of file diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestClient.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestClient.kt new file mode 100644 index 0000000000..532b7d1c94 --- /dev/null +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestClient.kt @@ -0,0 +1,99 @@ +package net.corda.testing.internal + +import io.netty.bootstrap.Bootstrap +import io.netty.channel.ChannelFuture +import io.netty.channel.ChannelInboundHandlerAdapter +import io.netty.handler.ssl.SslContext +import io.netty.channel.ChannelInitializer +import io.netty.channel.ChannelOption +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.handler.ssl.SslHandler +import java.io.Closeable +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import java.util.concurrent.locks.ReentrantLock +import javax.net.ssl.SSLEngine +import kotlin.concurrent.thread + + +class NettyTestClient( + val sslContext: SslContext?, + val targetHost: String, + val targetPort: Int, + val handler: ChannelInboundHandlerAdapter +) : Closeable { + internal var mainThread: Thread? = null + internal var channelFuture: ChannelFuture? = null + + // lock/condition to make sure that start only returns when the server is actually running + private val lock = ReentrantLock() + private val condition = lock.newCondition() + + var engine: SSLEngine? = null + private set + + fun start() { + try { + lock.lock() + mainThread = thread(start = true) { run() } + if (!condition.await(5, TimeUnit.SECONDS)) { + throw TimeoutException("Netty test server failed to start") + } + } finally { + lock.unlock() + + } + } + + private fun run() { + // Configure the client. + val group = NioEventLoopGroup() + try { + val b = Bootstrap() + b.group(group) + .channel(NioSocketChannel::class.java) + .option(ChannelOption.TCP_NODELAY, true) + .handler(object : ChannelInitializer() { + @Throws(Exception::class) + public override fun initChannel(ch: SocketChannel) { + val p = ch.pipeline() + if (sslContext != null) { + engine = sslContext.newEngine(ch.alloc(), targetHost, targetPort) + p.addLast(SslHandler(engine)) + } + //p.addLast(new LoggingHandler(LogLevel.INFO)); + p.addLast(handler) + } + }) + + // Start the client. + val f = b.connect(targetHost, targetPort) + try { + lock.lock() + condition.signal() + channelFuture = f.sync() + } finally { + lock.unlock() + } + + // Wait until the connection is closed. + f.channel().closeFuture().sync() + } finally { + // Shut down the event loop to terminate all threads. + group.shutdownGracefully() + } + } + + fun stop() { + channelFuture?.channel()?.close() + mainThread?.join() + mainThread = null + channelFuture = null + } + + override fun close() { + stop() + } +} \ No newline at end of file diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestHandler.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestHandler.kt new file mode 100644 index 0000000000..b62ed09805 --- /dev/null +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestHandler.kt @@ -0,0 +1,74 @@ +package net.corda.testing.internal + +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.channel.Channel +import io.netty.channel.ChannelDuplexHandler +import io.netty.channel.ChannelHandlerContext +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +class NettyTestHandler(val onMessageFunc: (ctx: ChannelHandlerContext?, msg: Any?) -> Unit = { _, _ -> }) : ChannelDuplexHandler() { + private var channel: Channel? = null + private var failure: Throwable? = null + + private val lock = ReentrantLock() + private val condition = lock.newCondition() + + var readCalledCounter: Int = 0 + private set + + override fun channelRegistered(ctx: ChannelHandlerContext?) { + channel = ctx?.channel() + super.channelRegistered(ctx) + } + + override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) { + try { + lock.lock() + readCalledCounter++ + onMessageFunc(ctx, msg) + } catch( e: Throwable ){ + failure = e + } finally { + condition.signal() + lock.unlock() + } + } + + fun writeString(msg: String) { + val buffer = Unpooled.wrappedBuffer(msg.toByteArray()) + require(channel != null) { "Channel must be registered before sending messages" } + channel!!.writeAndFlush(buffer) + } + + fun rethrowIfFailed() { + failure?.also { throw it } + } + + fun waitForReadCalled(numberOfExpectedCalls: Int = 1): Boolean { + try { + lock.lock() + if (readCalledCounter >= numberOfExpectedCalls) { + return true + } + while (readCalledCounter < numberOfExpectedCalls) { + if (!condition.await(5, TimeUnit.SECONDS)) { + return false + } + } + return true + } finally { + lock.unlock() + } + } + + companion object { + fun readString(buffer: Any?): String { + checkNotNull(buffer) + val ar = ByteArray((buffer as ByteBuf).readableBytes()) + buffer.readBytes(ar) + return String(ar) + } + } +} \ No newline at end of file diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestServer.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestServer.kt new file mode 100644 index 0000000000..a9effdd279 --- /dev/null +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/NettyTestServer.kt @@ -0,0 +1,97 @@ +package net.corda.testing.internal + +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.ChannelFuture +import io.netty.channel.ChannelInboundHandlerAdapter +import io.netty.channel.ChannelInitializer +import io.netty.channel.ChannelOption +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler +import io.netty.handler.ssl.SslContext +import java.io.Closeable +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException +import java.util.concurrent.locks.Condition +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.thread + + +class NettyTestServer( + private val sslContext: SslContext?, + val messageHandler: ChannelInboundHandlerAdapter, + val port: Int +) : Closeable { + internal var mainThread: Thread? = null + internal var channel: ChannelFuture? = null + + // lock/condition to make sure that start only returns when the server is actually running + val lock = ReentrantLock() + val condition: Condition = lock.newCondition() + + fun start() { + try { + lock.lock() + mainThread = thread(start = true) { run() } + if (!condition.await(5, TimeUnit.SECONDS)) { + throw TimeoutException("Netty test server failed to start") + } + } finally { + lock.unlock() + } + } + + fun run() { + // Configure the server. + val bossGroup = NioEventLoopGroup(1) + val workerGroup = NioEventLoopGroup() + try { + val b = ServerBootstrap() + b.group(bossGroup, workerGroup) + .channel(NioServerSocketChannel::class.java) + .option(ChannelOption.SO_BACKLOG, 100) + .handler(LoggingHandler(LogLevel.INFO)) + .childHandler(object : ChannelInitializer() { + @Throws(Exception::class) + public override fun initChannel(ch: SocketChannel) { + val p = ch.pipeline() + if (sslContext != null) { + p.addLast(sslContext.newHandler(ch.alloc())) + } + //p.addLast(new LoggingHandler(LogLevel.INFO)); + p.addLast(messageHandler) + } + }) + + // Start the server. + val f = b.bind(port) + try { + lock.lock() + channel = f.sync() + condition.signal() + } finally { + lock.unlock() + } + + // Wait until the server socket is closed. + channel!!.channel().closeFuture().sync() + } finally { + // Shut down all event loops to terminate all threads. + bossGroup.shutdownGracefully() + workerGroup.shutdownGracefully() + } + } + + fun stop() { + channel?.channel()?.close() + mainThread?.join() + channel = null + mainThread = null + } + + override fun close() { + stop() + } +} \ No newline at end of file diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/stubs/CertificateStoreStubs.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/stubs/CertificateStoreStubs.kt index 62f871ca8c..4de8c701b2 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/stubs/CertificateStoreStubs.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/stubs/CertificateStoreStubs.kt @@ -42,11 +42,11 @@ class CertificateStoreStubs { companion object { @JvmStatic - fun withCertificatesDirectory(certificatesDirectory: Path, keyStoreFileName: String = KeyStore.DEFAULT_STORE_FILE_NAME, keyStorePassword: String = KeyStore.DEFAULT_STORE_PASSWORD, trustStoreFileName: String = TrustStore.DEFAULT_STORE_FILE_NAME, trustStorePassword: String = TrustStore.DEFAULT_STORE_PASSWORD): MutualSslConfiguration { + fun withCertificatesDirectory(certificatesDirectory: Path, keyStoreFileName: String = KeyStore.DEFAULT_STORE_FILE_NAME, keyStorePassword: String = KeyStore.DEFAULT_STORE_PASSWORD, trustStoreFileName: String = TrustStore.DEFAULT_STORE_FILE_NAME, trustStorePassword: String = TrustStore.DEFAULT_STORE_PASSWORD, useOpenSsl: Boolean = false): MutualSslConfiguration { val keyStore = FileBasedCertificateStoreSupplier(certificatesDirectory / keyStoreFileName, keyStorePassword) val trustStore = FileBasedCertificateStoreSupplier(certificatesDirectory / trustStoreFileName, trustStorePassword) - return SslConfiguration.mutual(keyStore, trustStore) + return SslConfiguration.mutual(keyStore, trustStore, useOpenSsl) } @JvmStatic diff --git a/testing/test-utils/src/test/kotlin/net/corda/testing/internal/TestNettyTestInfra.kt b/testing/test-utils/src/test/kotlin/net/corda/testing/internal/TestNettyTestInfra.kt new file mode 100644 index 0000000000..a16550ec4c --- /dev/null +++ b/testing/test-utils/src/test/kotlin/net/corda/testing/internal/TestNettyTestInfra.kt @@ -0,0 +1,88 @@ +package net.corda.testing.internal + +import io.netty.channel.ChannelInboundHandlerAdapter +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class TestNettyTestInfra { + @Test + fun testStartAndStopServer() { + val testHandler = rigorousMock() + NettyTestServer(null, testHandler, 56234).use { server -> + server.start() + assertNotNull(server.mainThread) + assertNotNull(server.channel) + } + } + + @Test + fun testStartAndStopClient() { + val serverHandler = ChannelInboundHandlerAdapter() + val clientHandler = ChannelInboundHandlerAdapter() + + NettyTestServer(null, serverHandler, 56234).use { server -> + server.start() + + NettyTestClient(null, "localhost", 56234, clientHandler).use { client -> + client.start() + assertNotNull(client.mainThread) + assertNotNull(client.channelFuture) + } + } + } + + @Test + fun testPingPong() { + val serverHandler = NettyTestHandler { ctx, msg -> + ctx?.writeAndFlush(msg) + } + val clientHandler = NettyTestHandler { _, msg -> + assertEquals("ping", NettyTestHandler.readString(msg)) + } + NettyTestServer(null, serverHandler, 56234).use { server -> + server.start() + + NettyTestClient(null, "localhost", 56234, clientHandler).use { client -> + client.start() + + clientHandler.writeString("ping") + assertTrue(clientHandler.waitForReadCalled(1)) + clientHandler.rethrowIfFailed() + assertEquals(1, clientHandler.readCalledCounter) + assertEquals(1, serverHandler.readCalledCounter) + } + } + } + + @Test + fun testFailureHandling() { + val serverHandler = NettyTestHandler { ctx, msg -> + ctx?.writeAndFlush(msg) + } + val clientHandler = NettyTestHandler { _, msg -> + assertEquals("pong", NettyTestHandler.readString(msg)) + } + NettyTestServer(null, serverHandler, 56234).use { server -> + server.start() + + NettyTestClient(null, "localhost", 56234, clientHandler).use { client -> + client.start() + + clientHandler.writeString("ping") + assertTrue(clientHandler.waitForReadCalled(1)) + var exceptionThrown = false + try { + clientHandler.rethrowIfFailed() + } catch (e: AssertionError) { + exceptionThrown = true + } + assertTrue(exceptionThrown, "Expected assertion failure has not been thrown") + assertEquals(1, serverHandler.readCalledCounter) + assertEquals(1, clientHandler.readCalledCounter) + } + } + } + +} \ No newline at end of file