ENT-1565 Enable the use of BoringSSL (#1358)

* BoringSsl dependency

* Merge over boring_ssl changes

* Merge over boring_ssl changes

*  Upgrade netty-tcnative (and netty to compatible version)

* Add openSSL flag to SSLConfiguration and implementations.

* Make SSL implementation switchable for Artemis

* Parameterize AMQP bridge tests on use of openSSL

* Plumb through open SSL flag to AMQP client/server.

* Add open ssl flag to reference.conf

* Slight clean-up

* Add LoggingTrustManagerWrapper for OpenSsl contexts

* Remove unneeded lazy and check for double wrapping

* Fix TrustMangerWrapper and test, clean-up

* Add key factory wrapper to get the current certificate chain out.

* Use cert chain returning key mananager factory to get local cert

* Force consistent netty-tcnative version across all dependencies

* Make proton wrapper tests check all combinations of client/server native/java SSL

* Add test netty server/client to run SSL tests with

* Simplify usage of test netty components and clean up

* Improve exception handling in NettyTestHandler

* Add openSSL test for X509UtilitiesTests

* Expose engine for test usage

* Add the X509 peer chain check from the socket based test

* Port of TLSAuthenticationTests to use Netty so we can use different SSL providers, add boringSSL tests

* Adapt tests to new config structure

* Readd `useOpenSsl` configuration

* Readd `useOpenSsl` configuration

* Fix up ArtemisTransport for OpenSSL plus tests

* Adapt auth tests

* Formatting

* Remove obsolte file

* Fix config misnomer

* Add SNI host logic to OpenSSL execution branch

* Remove TLS_DHE_RSA tests

* Make exception handling in the netty test infra deterministic
This commit is contained in:
Christian Sailer 2018-10-01 13:59:52 +01:00 committed by GitHub
parent 9979035280
commit 532d95ccac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 1267 additions and 51 deletions

View File

@ -19,7 +19,8 @@ data class BridgeSSLConfigurationImpl(private val sslKeystore: Path,
private val keyStorePassword: String,
private val trustStoreFile: Path,
private val trustStorePassword: String,
private val crlCheckSoftFail: Boolean) : BridgeSSLConfiguration {
private val crlCheckSoftFail: Boolean,
override val useOpenSsl: Boolean = false) : BridgeSSLConfiguration {
override val keyStore = FileBasedCertificateStoreSupplier(sslKeystore, keyStorePassword)
override val trustStore = FileBasedCertificateStoreSupplier(trustStoreFile, trustStorePassword)

View File

@ -37,7 +37,8 @@ buildscript {
ext.metrics_version = constants.getProperty("metricsVersion")
ext.metrics_new_relic_version = constants.getProperty("metricsNewRelicVersion")
ext.okhttp_version = '3.5.0'
ext.netty_version = '4.1.22.Final'
ext.netty_version = '4.1.29.Final'
ext.tcnative_version = '2.0.14.Final'
ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion")
ext.fileupload_version = '1.3.3'
ext.junit_version = '4.12'
@ -264,7 +265,11 @@ allprojects {
// Demand that everything uses our given version of Netty.
eachDependency { details ->
if (details.requested.group == 'io.netty' && details.requested.name.startsWith('netty-')) {
details.useVersion netty_version
if (details.requested.name.startsWith('netty-tcnative')){
details.useVersion tcnative_version
} else {
details.useVersion netty_version
}
}
}
}

View File

@ -57,6 +57,7 @@ dependencies {
testCompile "org.assertj:assertj-core:$assertj_version"
testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"
testCompile project(':node-driver')
testCompile project(':test-utils')
compile ("org.apache.activemq:artemis-amqp-protocol:${artemis_version}") {
// Gains our proton-j version from core module.

View File

@ -100,30 +100,32 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration {
return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL)
return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false)
}
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration {
return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL)
return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false)
}
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true): TransportConfiguration {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration {
val options = defaultArtemisOptions(hostAndPort).toMutableMap()
if (enableSSL) {
options.putAll(defaultSSLOptions)
(keyStore to trustStore).addToTransportOptions(options)
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
}
return TransportConfiguration(acceptorFactoryClassName, options)
}
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true): TransportConfiguration {
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration {
val options = defaultArtemisOptions(hostAndPort).toMutableMap()
if (enableSSL) {
options.putAll(defaultSSLOptions)
(keyStore to trustStore).addToTransportOptions(options)
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
}
return TransportConfiguration(connectorFactoryClassName, options)
}

View File

@ -43,8 +43,13 @@ class AMQPBridgeManager(config: MutualSslConfiguration, socksProxyConfig: SocksP
private class AMQPConfigurationImpl private constructor(override val keyStore: CertificateStore,
override val trustStore: CertificateStore,
override val socksProxyConfig: SocksProxyConfig?,
override val maxMessageSize: Int) : AMQPConfiguration {
constructor(config: MutualSslConfiguration, socksProxyConfig: SocksProxyConfig?, maxMessageSize: Int) : this(config.keyStore.get(), config.trustStore.get(), socksProxyConfig, maxMessageSize)
override val maxMessageSize: Int,
override val useOpenSsl: Boolean) : AMQPConfiguration {
constructor(config: MutualSslConfiguration, socksProxyConfig: SocksProxyConfig?, maxMessageSize: Int) : this(config.keyStore.get(),
config.trustStore.get(),
socksProxyConfig,
maxMessageSize,
config.useOpenSsl)
}
private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(config, socksProxyConfig, maxMessageSize)

View File

@ -4,12 +4,13 @@ interface SslConfiguration {
val keyStore: FileBasedCertificateStoreSupplier?
val trustStore: FileBasedCertificateStoreSupplier?
val useOpenSsl: Boolean
companion object {
fun mutual(keyStore: FileBasedCertificateStoreSupplier, trustStore: FileBasedCertificateStoreSupplier): MutualSslConfiguration {
fun mutual(keyStore: FileBasedCertificateStoreSupplier, trustStore: FileBasedCertificateStoreSupplier, useOpenSsl: Boolean = false ): MutualSslConfiguration {
return MutualSslOptions(keyStore, trustStore)
return MutualSslOptions(keyStore, trustStore, useOpenSsl)
}
}
}
@ -20,4 +21,4 @@ interface MutualSslConfiguration : SslConfiguration {
override val trustStore: FileBasedCertificateStoreSupplier
}
private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier, override val trustStore: FileBasedCertificateStoreSupplier) : MutualSslConfiguration
private class MutualSslOptions(override val keyStore: FileBasedCertificateStoreSupplier, override val trustStore: FileBasedCertificateStoreSupplier, override val useOpenSsl: Boolean ) : MutualSslConfiguration

View File

@ -34,6 +34,7 @@ import javax.net.ssl.SSLException
*/
internal class AMQPChannelHandler(private val serverMode: Boolean,
private val allowedRemoteLegalNames: Set<CordaX500Name>?,
private var keyManagerFactory: CertHoldingKeyManagerFactoryWrapper,
private val userName: String?,
private val password: String?,
private val trace: Boolean,
@ -45,11 +46,11 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
}
private lateinit var remoteAddress: InetSocketAddress
private var localCert: X509Certificate? = null
private var remoteCert: X509Certificate? = null
private var eventProcessor: EventProcessor? = null
private var suppressClose: Boolean = false
private var badCert: Boolean = false
private var localCert: X509Certificate? = null
private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap()
@ -122,7 +123,18 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
if (evt is SslHandshakeCompletionEvent) {
if (evt.isSuccess) {
val sslHandler = ctx.pipeline().get(SslHandler::class.java)
localCert = sslHandler.engine().session.localCertificates[0].x509
val sslSession = sslHandler.engine().session
localCert = keyManagerFactory.getCurrentCertChain()?.get(0)
if (localCert == null) {
log.error("SSL KeyManagerFactory failed to provide a local cert")
ctx.close()
return
}
if (sslSession.peerCertificates == null || sslSession.peerCertificates.isEmpty()) {
log.error("No peer certificates")
ctx.close()
return
}
remoteCert = sslHandler.engine().session.peerCertificates[0].x509
val remoteX500Name = try {
CordaX500Name.build(remoteCert!!.subjectX500Principal)
@ -151,7 +163,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
} else {
badCert = true
}
logErrorWithMDC("Handshake failure ${evt.cause().message}")
logErrorWithMDC("Handshake failure: ${evt.cause().message}")
if (log.isTraceEnabled) {
withMDC { log.trace("Handshake failure", evt.cause()) }
}

View File

@ -22,6 +22,7 @@ import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
@ -157,12 +158,18 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
}
}
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory)
val target = parent.currentTarget
val handler = createClientSslHelper(target, parent.allowedRemoteLegalNames, keyManagerFactory, trustManagerFactory)
val handler = if (parent.configuration.useOpenSsl){
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
} else {
createClientSslHelper(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory)
}
pipeline.addLast("sslHandler", handler)
if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
pipeline.addLast(AMQPChannelHandler(false,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
conf.userName,
conf.password,
conf.trace,

View File

@ -55,5 +55,12 @@ interface AMQPConfiguration {
@JvmDefault
val socksProxyConfig: SocksProxyConfig?
get() = null
/**
* Whether to use the tcnative open/boring SSL provider or the default Java SSL provider
*/
@JvmDefault
val useOpenSsl: Boolean
get() = false
}

View File

@ -23,6 +23,7 @@ import rx.Observable
import rx.subjects.PublishSubject
import java.net.BindException
import java.net.InetSocketAddress
import java.security.cert.X509Certificate
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
@ -66,11 +67,17 @@ class AMQPServer(val hostName: String,
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
val handler = createServerSslHelper(keyManagerFactory, trustManagerFactory)
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory)
val handler = if (parent.configuration.useOpenSsl){
createServerOpenSslHandler(wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
} else {
createServerSslHelper(wrappedKeyManagerFactory, trustManagerFactory)
}
pipeline.addLast("sslHandler", handler)
if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
pipeline.addLast(AMQPChannelHandler(true,
null,
wrappedKeyManagerFactory,
conf.userName,
conf.password,
conf.trace,

View File

@ -0,0 +1,94 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import java.net.Socket
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.X509Certificate
import javax.net.ssl.SSLEngine
import javax.net.ssl.X509ExtendedKeyManager
import javax.net.ssl.X509KeyManager
interface AliasProvidingKeyMangerWrapper : X509KeyManager {
var lastAlias: String?
}
class AliasProvidingKeyMangerWrapperImpl(private val keyManager: X509KeyManager) : AliasProvidingKeyMangerWrapper {
override var lastAlias: String? = null
override fun getClientAliases(p0: String?, p1: Array<out Principal>?): Array<String> {
return keyManager.getClientAliases(p0, p1)
}
override fun getServerAliases(p0: String?, p1: Array<out Principal>?): Array<String> {
return getServerAliases(p0, p1)
}
override fun chooseServerAlias(p0: String?, p1: Array<out Principal>?, p2: Socket?): String? {
return storeIfNotNull { keyManager.chooseServerAlias(p0, p1, p2) }
}
override fun getCertificateChain(p0: String?): Array<X509Certificate> {
return keyManager.getCertificateChain(p0)
}
override fun getPrivateKey(p0: String?): PrivateKey {
return keyManager.getPrivateKey(p0)
}
override fun chooseClientAlias(p0: Array<out String>?, p1: Array<out Principal>?, p2: Socket?): String? {
return storeIfNotNull { keyManager.chooseClientAlias(p0, p1, p2) }
}
private fun storeIfNotNull(func: () -> String?): String? {
val alias = func()
if (alias != null) {
lastAlias = alias
}
return alias
}
}
class AliasProvidingExtendedKeyMangerWrapper(private val keyManager: X509ExtendedKeyManager) : X509ExtendedKeyManager(), AliasProvidingKeyMangerWrapper {
override var lastAlias: String? = null
override fun getClientAliases(p0: String?, p1: Array<out Principal>?): Array<String> {
return keyManager.getClientAliases(p0, p1)
}
override fun getServerAliases(p0: String?, p1: Array<out Principal>?): Array<String> {
return keyManager.getServerAliases(p0, p1)
}
override fun chooseServerAlias(p0: String?, p1: Array<out Principal>?, p2: Socket?): String? {
return storeIfNotNull { keyManager.chooseServerAlias(p0, p1, p2) }
}
override fun getCertificateChain(p0: String?): Array<X509Certificate> {
return keyManager.getCertificateChain(p0)
}
override fun getPrivateKey(p0: String?): PrivateKey {
return keyManager.getPrivateKey(p0)
}
override fun chooseClientAlias(p0: Array<out String>?, p1: Array<out Principal>?, p2: Socket?): String? {
return storeIfNotNull { keyManager.chooseClientAlias(p0, p1, p2) }
}
override fun chooseEngineClientAlias(p0: Array<out String>?, p1: Array<out Principal>?, p2: SSLEngine?): String? {
return storeIfNotNull { keyManager.chooseEngineClientAlias(p0, p1, p2) }
}
override fun chooseEngineServerAlias(p0: String?, p1: Array<out Principal>?, p2: SSLEngine?): String? {
return storeIfNotNull { keyManager.chooseEngineServerAlias(p0, p1, p2) }
}
private fun storeIfNotNull(func: () -> String?): String? {
val alias = func()
if (alias != null) {
lastAlias = alias
}
return alias
}
}

View File

@ -0,0 +1,66 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import java.security.KeyStore
import java.security.cert.X509Certificate
import javax.net.ssl.*
class CertHoldingKeyManagerFactorySpiWrapper(private val factorySpi: KeyManagerFactorySpi) : KeyManagerFactorySpi() {
override fun engineInit(p0: KeyStore?, p1: CharArray?) {
val engineInitMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineInit", KeyStore::class.java, CharArray::class.java)
engineInitMethod.isAccessible = true
engineInitMethod.invoke(factorySpi, p0, p1)
}
override fun engineInit(p0: ManagerFactoryParameters?) {
val engineInitMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineInit", ManagerFactoryParameters::class.java)
engineInitMethod.isAccessible = true
engineInitMethod.invoke(factorySpi, p0)
}
private fun getKeyManagersImpl(): Array<KeyManager> {
val engineGetKeyManagersMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineGetKeyManagers")
engineGetKeyManagersMethod.isAccessible = true
@Suppress("UNCHECKED_CAST")
val keyManagers = engineGetKeyManagersMethod.invoke(factorySpi) as Array<KeyManager>
return if (factorySpi is CertHoldingKeyManagerFactorySpiWrapper) keyManagers else keyManagers.mapNotNull {
@Suppress("USELESS_CAST") // the casts to KeyManager are not useless - without them, the typed array will be of type Any
when (it) {
is X509ExtendedKeyManager -> AliasProvidingExtendedKeyMangerWrapper(it) as KeyManager
is X509KeyManager -> AliasProvidingKeyMangerWrapperImpl(it) as KeyManager
else -> null
}
}.toTypedArray()
}
private val keyManagers = lazy { getKeyManagersImpl() }
override fun engineGetKeyManagers(): Array<KeyManager> {
return keyManagers.value
}
}
/**
* You can wrap a key manager factory in this class if you need to get the cert chain currently used to identify or
* verify. When using for TLS channels, make sure to wrap the (singleton) factory separately on each channel, as
* the wrapper is not thread safe as in it will return the last used alias/cert chain and has itself no notion
* of belonging to a certain channel.
*/
class CertHoldingKeyManagerFactoryWrapper(factory: KeyManagerFactory) : KeyManagerFactory(getFactorySpi(factory), factory.provider, factory.algorithm) {
companion object {
private fun getFactorySpi(factory: KeyManagerFactory): KeyManagerFactorySpi {
val spiField = KeyManagerFactory::class.java.getDeclaredField("factorySpi")
spiField.isAccessible = true
return CertHoldingKeyManagerFactorySpiWrapper(spiField.get(factory) as KeyManagerFactorySpi)
}
}
fun getCurrentCertChain(): Array<out X509Certificate>? {
val keyManager = keyManagers.firstOrNull()
val alias = if (keyManager is AliasProvidingKeyMangerWrapper) keyManager.lastAlias else null
return if (alias != null && keyManager is X509KeyManager) {
keyManager.getCertificateChain(alias)
} else null
}
}

View File

@ -1,6 +1,9 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslHandler
import io.netty.handler.ssl.SslProvider
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
@ -126,6 +129,23 @@ internal fun createClientSslHelper(target: NetworkHostAndPort,
return SslHandler(sslEngine)
}
internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray()
if (expectedRemoteLegalNames.size == 1) {
val sslParameters = sslEngine.sslParameters
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine)
}
internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS")
@ -159,6 +179,18 @@ internal fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateSto
return CertPathTrustManagerParameters(pkixParams)
}
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
val sslContext = SslContextBuilder.forServer(keyManagerFactory).sslProvider(SslProvider.OPENSSL).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false
sslEngine.needClientAuth = true
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray()
return SslHandler(sslEngine)
}
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.password.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal)

View File

@ -0,0 +1,36 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import java.security.KeyStore
import javax.net.ssl.*
class LoggingTrustManagerFactorySpiWrapper(private val factorySpi: TrustManagerFactorySpi) : TrustManagerFactorySpi() {
override fun engineGetTrustManagers(): Array<TrustManager> {
val engineGetTrustManagersMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineGetTrustManagers")
engineGetTrustManagersMethod.isAccessible = true
@Suppress("UNCHECKED_CAST")
val trustManagers = engineGetTrustManagersMethod.invoke(factorySpi) as Array<TrustManager>
return if (factorySpi is LoggingTrustManagerFactorySpiWrapper) trustManagers else trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
}
override fun engineInit(p0: KeyStore?) {
val engineInitMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineInit", KeyStore::class.java)
engineInitMethod.isAccessible = true
engineInitMethod.invoke(factorySpi, p0)
}
override fun engineInit(p0: ManagerFactoryParameters?) {
val engineInitMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineInit", ManagerFactoryParameters::class.java)
engineInitMethod.isAccessible = true
engineInitMethod.invoke(factorySpi, p0)
}
}
class LoggingTrustManagerFactoryWrapper(factory: TrustManagerFactory) : TrustManagerFactory(getFactorySpi(factory), factory.provider, factory.algorithm) {
companion object {
private fun getFactorySpi(factory: TrustManagerFactory): TrustManagerFactorySpi {
val spiField = TrustManagerFactory::class.java.getDeclaredField("factorySpi")
spiField.isAccessible = true
return LoggingTrustManagerFactorySpiWrapper(spiField.get(factory) as TrustManagerFactorySpi)
}
}
}

View File

@ -1,5 +1,10 @@
package net.corda.nodeapi.internal.crypto
import io.netty.handler.ssl.ClientAuth
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.*
import net.corda.core.crypto.Crypto.COMPOSITE_KEY
import net.corda.core.crypto.Crypto.ECDSA_SECP256K1_SHA256
@ -28,8 +33,12 @@ import net.corda.serialization.internal.amqp.amqpMagic
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.internal.stubs.CertificateStoreStubs
import net.corda.testing.driver.PortAllocation
import net.corda.testing.internal.NettyTestClient
import net.corda.testing.internal.NettyTestHandler
import net.corda.testing.internal.NettyTestServer
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.testing.internal.stubs.CertificateStoreStubs
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.*
@ -65,6 +74,9 @@ class X509UtilitiesTest {
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
)
val portAllocation = PortAllocation.Incremental(10000)
// We ensure that all of the algorithms are both used (at least once) as first and second in the following [Pair]s.
// We also add [DEFAULT_TLS_SIGNATURE_SCHEME] and [DEFAULT_IDENTITY_SIGNATURE_SCHEME] combinations for consistency.
val certChainSchemeCombinations = listOf(
@ -348,6 +360,65 @@ class X509UtilitiesTest {
assertTrue(done)
}
@Test
fun `create server cert and use in OpenSSL channel`() {
val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(tempFolder.root.toPath(), keyStorePassword = "serverstorepass")
val (rootCa, intermediateCa) = createDevIntermediateCaCertPath()
// Generate server cert and private key and populate another keystore suitable for SSL
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
val sslServerContext = SslContextBuilder
.forServer(keyManagerFactory)
.trustManager(trustManagerFactory)
.clientAuth(ClientAuth.REQUIRE)
.ciphers(CIPHER_SUITES.toMutableList())
.sslProvider(SslProvider.OPENSSL)
.protocols("TLSv1.2")
.build()
val sslClientContext = SslContextBuilder
.forClient()
.keyManager(keyManagerFactory)
.trustManager(trustManagerFactory)
.ciphers(CIPHER_SUITES.toMutableList())
.sslProvider(SslProvider.OPENSSL)
.protocols("TLSv1.2")
.build()
val serverHandler = NettyTestHandler { ctx, msg -> ctx?.writeAndFlush(msg) }
val clientHandler = NettyTestHandler { _, msg -> assertEquals("Hello", NettyTestHandler.readString(msg)) }
NettyTestServer(sslServerContext, serverHandler, portAllocation.nextPort()).use { server ->
server.start()
NettyTestClient(sslClientContext, InetAddress.getLocalHost().canonicalHostName, server.port, clientHandler).use { client ->
client.start()
clientHandler.writeString("Hello")
val readCalled = clientHandler.waitForReadCalled()
clientHandler.rethrowIfFailed()
serverHandler.rethrowIfFailed()
assertTrue(readCalled)
assertEquals(1, serverHandler.readCalledCounter)
assertEquals(1, clientHandler.readCalledCounter)
val peerChain = client.engine!!.session.peerCertificates.x509
val peerX500Principal = peerChain[0].subjectX500Principal
assertEquals(MEGA_CORP.name.x500Principal, peerX500Principal)
X509Utilities.validateCertificateChain(rootCa.certificate, peerChain)
}
}
}
private fun tempFile(name: String): Path = tempFolder.root.toPath() / name
private fun MutualSslConfiguration.createTrustStore(rootCert: X509Certificate) {

View File

@ -0,0 +1,95 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.div
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.stubs.CertificateStoreStubs
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.X509KeyManager
import kotlin.test.*
class TestKeyManagerFactoryWrapper {
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
private abstract class AbstractNodeConfiguration : NodeConfiguration
@Test
fun testWrapping() {
val baseDir = temporaryFolder.root.toPath() / "testWrapping"
val certDir = baseDir / "certificates"
val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(temporaryFolder.root.toPath(), keyStorePassword = "serverstorepass")
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(temporaryFolder.root.toPath())
val config = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDir).whenever(it).baseDirectory
doReturn(certDir).whenever(it).certificatesDirectory
doReturn(ALICE_NAME).whenever(it).myLegalName
doReturn(sslConfig).whenever(it).p2pSslOptions
doReturn(signingCertificateStore).whenever(it).signingCertificateStore
}
config.configureWithDevSSLCertificate()
val underlyingKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(underlyingKeyManagerFactory)
wrappedKeyManagerFactory.init(config.p2pSslOptions.keyStore.get())
val keyManagers = wrappedKeyManagerFactory.keyManagers
assertFalse(keyManagers.isEmpty())
assertNull(wrappedKeyManagerFactory.getCurrentCertChain())
val keyManager = keyManagers.first() as X509KeyManager
val alias = keyManager.chooseClientAlias(arrayOf("EC_EC"), null, null)
assertNotNull(alias)
val certChain = wrappedKeyManagerFactory.getCurrentCertChain()
assertNotNull(certChain)
assertTrue(certChain!!.isNotEmpty())
assertEquals(alias, (keyManager as AliasProvidingKeyMangerWrapper).lastAlias)
}
@Test
fun testWrappingSeparately() {
val baseDir = temporaryFolder.root.toPath() / "testWrapping"
val certDir = baseDir / "certificates"
val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(temporaryFolder.root.toPath(), keyStorePassword = "serverstorepass")
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(temporaryFolder.root.toPath())
val config = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDir).whenever(it).baseDirectory
doReturn(certDir).whenever(it).certificatesDirectory
doReturn(ALICE_NAME).whenever(it).myLegalName
doReturn(sslConfig).whenever(it).p2pSslOptions
doReturn(signingCertificateStore).whenever(it).signingCertificateStore
}
config.configureWithDevSSLCertificate()
val underlyingKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(underlyingKeyManagerFactory)
wrappedKeyManagerFactory.init(config.p2pSslOptions.keyStore.get())
val otherWrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(underlyingKeyManagerFactory)
val keyManagers = wrappedKeyManagerFactory.keyManagers
assertFalse(keyManagers.isEmpty())
assertNull(wrappedKeyManagerFactory.getCurrentCertChain())
val keyManager = keyManagers.first() as X509KeyManager
keyManager.chooseClientAlias(arrayOf("EC_EC"), null, null)
val certChain = wrappedKeyManagerFactory.getCurrentCertChain()
assertNotNull(certChain)
assertTrue(certChain!!.isNotEmpty())
assertNull(otherWrappedKeyManagerFactory.getCurrentCertChain())
}
}

View File

@ -0,0 +1,51 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.div
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.stubs.CertificateStoreStubs
import org.junit.Assert.assertTrue
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import javax.net.ssl.TrustManagerFactory
class TestTrustManagerFactoryWrapper {
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
private abstract class AbstractNodeConfiguration : NodeConfiguration
@Test
fun testWrapping() {
val baseDir = temporaryFolder.root.toPath() / "testWrapping"
val certDir = baseDir / "certificates"
val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(temporaryFolder.root.toPath(), keyStorePassword = "serverstorepass")
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(temporaryFolder.root.toPath())
val config = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDir).whenever(it).baseDirectory
doReturn(certDir).whenever(it).certificatesDirectory
doReturn(ALICE_NAME).whenever(it).myLegalName
doReturn(sslConfig).whenever(it).p2pSslOptions
doReturn(signingCertificateStore).whenever(it).signingCertificateStore
}
config.configureWithDevSSLCertificate()
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val wrapped = LoggingTrustManagerFactoryWrapper(trustManagerFactory)
wrapped.init(initialiseTrustStoreAndEnableCrlChecking(config.p2pSslOptions.trustStore.get(), false))
val trustManagers = wrapped.trustManagers
assertTrue(trustManagers.size > 0)
assertTrue(trustManagers[0] is LoggingTrustManagerWrapper)
}
}

View File

@ -223,8 +223,12 @@ dependencies {
testCompile("io.netty:netty-example:$netty_version") {
exclude group: "io.netty", module: "netty-tcnative"
exclude group: "ch.qos.logback", module: "logback-classic"
}
// Adding native SSL library to allow using native SSL with Artemis and AMQP
compile "io.netty:netty-tcnative-boringssl-static:$tcnative_version"
testCompile(project(':test-cli'))
}

View File

@ -5,11 +5,9 @@ import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.toStringShort
import net.corda.core.internal.div
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.hours
import net.corda.core.utilities.loggerFor
import net.corda.node.services.config.EnterpriseConfiguration
import net.corda.node.services.config.MutualExclusionConfiguration
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.config.*
import net.corda.node.services.messaging.ArtemisMessagingServer
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent
@ -23,8 +21,8 @@ import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.PortAllocation
import net.corda.testing.internal.stubs.CertificateStoreStubs
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.stubs.CertificateStoreStubs
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString
@ -34,12 +32,22 @@ import org.junit.Ignore
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.util.*
import kotlin.concurrent.thread
import kotlin.system.measureNanoTime
import kotlin.system.measureTimeMillis
import kotlin.test.assertEquals
class AMQPBridgeTest {
@RunWith(Parameterized::class)
class AMQPBridgeTest(private val useOpenSsl: Boolean) {
companion object {
@JvmStatic
@Parameterized.Parameters(name = "useOpenSsl = {0}")
fun data(): Collection<Boolean> = listOf(false, true)
}
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
@ -198,13 +206,30 @@ class AMQPBridgeTest {
var timeNanosCreateMessage = 0L
var timeNanosSendMessage = 0L
var timeMillisRead = 0L
val recThread = thread {
val current = artemisConsumer.receive()
val messageId = current.getIntProperty(P2PMessagingHeaders.senderUUID)
assertEquals(numReceived, messageId)
++numReceived
current.acknowledge()
timeMillisRead = measureTimeMillis {
while (numReceived < numMessages) {
val currentMsg = artemisConsumer.receive()
val loopMessageId = currentMsg.getIntProperty(P2PMessagingHeaders.senderUUID)
assertEquals(numReceived, loopMessageId)
++numReceived
currentMsg.acknowledge()
}
}
}
val simpleSourceQueueName = SimpleString(sourceQueueName)
val totalTimeMillis = measureTimeMillis {
repeat(numMessages) {
repeat(numMessages) { i ->
var artemisMessage: ClientMessage? = null
timeNanosCreateMessage += measureNanoTime {
artemisMessage = artemis.session.createMessage(true).apply {
putIntProperty("CountProp", it)
putIntProperty(P2PMessagingHeaders.senderUUID, i)
writeBodyBufferBytes(rubbishPayload)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
@ -215,18 +240,9 @@ class AMQPBridgeTest {
}
}
artemisClient.started!!.session.commit()
timeMillisRead = measureTimeMillis {
while (numReceived < numMessages) {
val current = artemisConsumer.receive()
val messageId = current.getIntProperty("CountProp")
assertEquals(numReceived, messageId)
++numReceived
current.acknowledge()
}
}
recThread.join(1.hours.toMillis())
}
println("Creating $numMessages messages took ${timeNanosCreateMessage / (1000 * 1000)} milliseconds")
println("Sending $numMessages messages took ${timeNanosSendMessage / (1000 * 1000)} milliseconds")
println("Receiving $numMessages messages took $timeMillisRead milliseconds")
@ -244,7 +260,7 @@ class AMQPBridgeTest {
private fun createArtemis(sourceQueueName: String?): Triple<ArtemisMessagingServer, ArtemisMessagingClient, BridgeManager> {
val baseDir = temporaryFolder.root.toPath() / "artemis"
val certificatesDirectory = baseDir / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = useOpenSsl)
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val artemisConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDir).whenever(it).baseDirectory
@ -260,7 +276,8 @@ class AMQPBridgeTest {
artemisConfig.configureWithDevSSLCertificate()
val artemisServer = ArtemisMessagingServer(artemisConfig, artemisAddress.copy(host = "0.0.0.0"), MAX_MESSAGE_SIZE)
val artemisClient = ArtemisMessagingClient(artemisConfig.p2pSslOptions, artemisAddress, MAX_MESSAGE_SIZE)
val artemisClient = ArtemisMessagingClient(artemisConfig.p2pSslOptions, artemisAddress, MAX_MESSAGE_SIZE, confirmationWindowSize = artemisConfig.enterpriseConfiguration.tuning.p2pConfirmationWindowSize)
artemisServer.start()
artemisClient.start()
@ -279,7 +296,7 @@ class AMQPBridgeTest {
private fun createArtemisReceiver(targetAdress: NetworkHostAndPort, workingDir: String): Pair<ArtemisMessagingServer, ArtemisMessagingClient> {
val baseDir = temporaryFolder.root.toPath() / workingDir
val certificatesDirectory = baseDir / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = useOpenSsl)
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val artemisConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDir).whenever(it).baseDirectory
@ -288,7 +305,9 @@ class AMQPBridgeTest {
doReturn(signingCertificateStore).whenever(it).signingCertificateStore
doReturn(p2pSslConfiguration).whenever(it).p2pSslOptions
doReturn(targetAdress).whenever(it).p2pAddress
doReturn("").whenever(it).jmxMonitoringHttpPort
doReturn(null).whenever(it).jmxMonitoringHttpPort
@Suppress("DEPRECATION")
doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies
doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration
}
artemisConfig.configureWithDevSSLCertificate()
@ -304,7 +323,7 @@ class AMQPBridgeTest {
private fun createAMQPServer(maxMessageSize: Int = MAX_MESSAGE_SIZE): AMQPServer {
val baseDir = temporaryFolder.root.toPath() / "server"
val certificatesDirectory = baseDir / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = useOpenSsl)
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val serverConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory
@ -321,6 +340,7 @@ class AMQPBridgeTest {
override val trustStore = serverConfig.p2pSslOptions.trustStore.get()
override val trace: Boolean = true
override val maxMessageSize: Int = maxMessageSize
override val useOpenSsl = serverConfig.p2pSslOptions.useOpenSsl
}
return AMQPServer("0.0.0.0",
amqpAddress.port,

View File

@ -40,13 +40,32 @@ import org.junit.Assert.assertArrayEquals
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.security.cert.X509Certificate
import javax.net.ssl.*
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class ProtonWrapperTests {
@RunWith(Parameterized::class)
class ProtonWrapperTests(val sslSetup: SslSetup) {
companion object {
data class SslSetup(val clientNative: Boolean, val serverNative: Boolean) {
override fun toString(): String = "Client: ${if (clientNative) "openSsl" else "javaSsl"} Server: ${if (serverNative) "openSsl" else "javaSsl"} "
}
@JvmStatic
@Parameterized.Parameters(name = "{0}")
fun data(): Collection<SslSetup> = listOf(
SslSetup(false, false),
SslSetup(true, false),
SslSetup(false, true),
SslSetup(true, true)
)
}
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
@ -407,7 +426,7 @@ class ProtonWrapperTests {
val baseDirectory = temporaryFolder.root.toPath() / "artemis"
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.serverNative)
val artemisConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDirectory).whenever(it).baseDirectory
doReturn(certificatesDirectory).whenever(it).certificatesDirectory
@ -432,7 +451,7 @@ class ProtonWrapperTests {
val baseDirectory = temporaryFolder.root.toPath() / "client"
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.clientNative)
val clientConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDirectory).whenever(it).baseDirectory
doReturn(certificatesDirectory).whenever(it).certificatesDirectory
@ -450,6 +469,7 @@ class ProtonWrapperTests {
override val trustStore = clientTruststore
override val trace: Boolean = true
override val maxMessageSize: Int = maxMessageSize
override val useOpenSsl: Boolean = sslSetup.clientNative
}
return AMQPClient(
listOf(NetworkHostAndPort("localhost", serverPort),
@ -463,7 +483,7 @@ class ProtonWrapperTests {
val baseDirectory = temporaryFolder.root.toPath() / "client_%$id"
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.clientNative)
val clientConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDirectory).whenever(it).baseDirectory
doReturn(certificatesDirectory).whenever(it).certificatesDirectory
@ -481,6 +501,7 @@ class ProtonWrapperTests {
override val trustStore = clientTruststore
override val trace: Boolean = true
override val maxMessageSize: Int = maxMessageSize
override val useOpenSsl: Boolean = sslSetup.clientNative
}
return AMQPClient(
listOf(NetworkHostAndPort("localhost", serverPort)),
@ -496,7 +517,7 @@ class ProtonWrapperTests {
val baseDirectory = temporaryFolder.root.toPath() / "server"
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, useOpenSsl = sslSetup.serverNative)
val serverConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(baseDirectory).whenever(it).baseDirectory
doReturn(certificatesDirectory).whenever(it).certificatesDirectory
@ -514,6 +535,7 @@ class ProtonWrapperTests {
override val trustStore = serverTruststore
override val trace: Boolean = true
override val maxMessageSize: Int = maxMessageSize
override val useOpenSsl: Boolean = sslSetup.serverNative
}
return AMQPServer(
"0.0.0.0",

View File

@ -280,7 +280,8 @@ data class NodeConfigurationImpl(
override val flowMonitorPeriodMillis: Duration = DEFAULT_FLOW_MONITOR_PERIOD_MILLIS,
override val flowMonitorSuspensionLoggingThresholdMillis: Duration = DEFAULT_FLOW_MONITOR_SUSPENSION_LOGGING_THRESHOLD_MILLIS,
override val cordappDirectories: List<Path> = listOf(baseDirectory / CORDAPPS_DIR_NAME_DEFAULT),
override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA
override val jmxReporterType: JmxReporterType? = JmxReporterType.JOLOKIA,
private val useOpenSsl: Boolean = false
) : NodeConfiguration {
companion object {
private val logger = loggerFor<NodeConfigurationImpl>()
@ -313,7 +314,7 @@ data class NodeConfigurationImpl(
private val p2pKeyStore = FileBasedCertificateStoreSupplier(p2pKeystorePath, keyStorePassword)
private val p2pTrustStoreFilePath: Path get() = certificatesDirectory / "truststore.jks"
private val p2pTrustStore = FileBasedCertificateStoreSupplier(p2pTrustStoreFilePath, trustStorePassword)
override val p2pSslOptions: MutualSslConfiguration = SslConfiguration.mutual(p2pKeyStore, p2pTrustStore)
override val p2pSslOptions: MutualSslConfiguration = SslConfiguration.mutual(p2pKeyStore, p2pTrustStore, useOpenSsl)
override val rpcOptions: NodeRpcOptions
get() {

View File

@ -1,6 +1,7 @@
emailAddress = "admin@company.com"
keyStorePassword = "cordacadevpass"
trustStorePassword = "trustpass"
useOpenSsl = false
crlCheckSoftFail = true
lazyBridgeStart = true
additionalP2PAddresses = []

View File

@ -0,0 +1,317 @@
package net.corda.node.utilities
import io.netty.handler.ssl.ClientAuth
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureScheme
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.nodeapi.internal.crypto.*
import net.corda.testing.driver.PortAllocation
import net.corda.testing.internal.NettyTestClient
import net.corda.testing.internal.NettyTestHandler
import net.corda.testing.internal.NettyTestServer
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.net.InetAddress
import java.nio.file.Path
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import javax.security.auth.x500.X500Principal
import kotlin.test.assertEquals
import kotlin.test.assertTrue
@RunWith(Parameterized::class)
class NettyEngineBasedTlsAuthenticationTests(val sslSetup: SslSetup) {
@Rule
@JvmField
val tempFolder: TemporaryFolder = TemporaryFolder()
// Root CA.
private val ROOT_X500 = X500Principal("CN=Root_CA_1,O=R3CEV,L=London,C=GB")
// Intermediate CA.
private val INTERMEDIATE_X500 = X500Principal("CN=Intermediate_CA_1,O=R3CEV,L=London,C=GB")
// TLS server (server).
private val CLIENT_1_X500 = CordaX500Name(commonName = "Client_1", organisation = "R3CEV", locality = "London", country = "GB")
// TLS client (client).
private val CLIENT_2_X500 = CordaX500Name(commonName = "Client_2", organisation = "R3CEV", locality = "London", country = "GB")
// Password for keys and keystores.
private val PASSWORD = "dummypassword"
// Default supported TLS schemes for Corda nodes.
private val CORDA_TLS_CIPHER_SUITES = arrayOf(
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"
)
private fun tempFile(name: String): Path = tempFolder.root.toPath() / name
companion object {
private val portAllocation = PortAllocation.Incremental(10000)
data class SslSetup(val clientNative: Boolean, val serverNative: Boolean) {
override fun toString(): String = "Client: ${if (clientNative) "openSsl" else "javaSsl"} Server: ${if (serverNative) "openSsl" else "javaSsl"} "
}
@JvmStatic
@Parameterized.Parameters(name = "{0}")
fun data(): Collection<SslSetup> = listOf(
SslSetup(false, false),
SslSetup(true, false),
SslSetup(false, true),
SslSetup(true, true)
)
}
@Test
fun `All EC R1`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256
)
testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
}
@Test
fun `All RSA`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.RSA_SHA256,
intermediateCAScheme = Crypto.RSA_SHA256,
serverCAScheme = Crypto.RSA_SHA256,
serverTLSScheme = Crypto.RSA_SHA256,
clientCAScheme = Crypto.RSA_SHA256,
clientTLSScheme = Crypto.RSA_SHA256
)
testConnect(serverContext, clientContext, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
}
// Server's public key type is the one selected if users use different key types (e.g RSA and EC R1).
@Test
fun `Server RSA - Client EC R1 - CAs all EC R1`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverTLSScheme = Crypto.RSA_SHA256,
clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256
)
testConnect(serverContext, clientContext, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256") // Server's key type is selected.
}
@Test
fun `Server EC R1 - Client RSA - CAs all EC R1`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientTLSScheme = Crypto.RSA_SHA256
)
testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") // Server's key type is selected.
}
@Test
fun `Server EC R1 - Client EC R1 - CAs all RSA`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.RSA_SHA256,
intermediateCAScheme = Crypto.RSA_SHA256,
serverCAScheme = Crypto.RSA_SHA256,
serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientCAScheme = Crypto.RSA_SHA256,
clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256
)
testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
}
@Test
fun `Server EC R1 - Client RSA - Mixed CAs`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
intermediateCAScheme = Crypto.RSA_SHA256,
serverCAScheme = Crypto.RSA_SHA256,
serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientTLSScheme = Crypto.RSA_SHA256
)
testConnect(serverContext, clientContext, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
}
// According to RFC 5246 (TLS 1.2), section 7.4.1.2 ClientHello cipher_suites:
// This is a list of the cryptographic options supported by the client, with the client's first preference first.
//
// However, the server is still free to ignore this order and pick what it thinks is best,
// see https://security.stackexchange.com/questions/121608 for more information.
@Test
fun `TLS cipher suite order matters - implementation dependent`() {
val (serverContext, clientContext) = buildContexts(
rootCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
intermediateCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
serverTLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientCAScheme = Crypto.ECDSA_SECP256R1_SHA256,
clientTLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
cipherSuitesServer = arrayOf("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256"), // GCM then CBC.
cipherSuitesClient = arrayOf("TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256") // CBC then GCM.
)
val expectedCipherSuite = if (sslSetup.clientNative || sslSetup.serverNative)
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" // server wins if boring ssl is involved
else
"TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256" // client wins in pure JRE SSL
testConnect(serverContext, clientContext, expectedCipherSuite)
}
private fun buildContexts(
rootCAScheme: SignatureScheme,
intermediateCAScheme: SignatureScheme,
serverCAScheme: SignatureScheme,
serverTLSScheme: SignatureScheme,
clientCAScheme: SignatureScheme,
clientTLSScheme: SignatureScheme,
cipherSuitesServer: Array<String> = CORDA_TLS_CIPHER_SUITES,
cipherSuitesClient: Array<String> = CORDA_TLS_CIPHER_SUITES
): Pair<SslContext, SslContext> {
val trustStorePath = tempFile("cordaTrustStore.jks")
val serverTLSKeyStorePath = tempFile("serversslkeystore.jks")
val clientTLSKeyStorePath = tempFile("clientsslkeystore.jks")
// ROOT CA key and cert.
val rootCAKeyPair = Crypto.generateKeyPair(rootCAScheme)
val rootCACert = X509Utilities.createSelfSignedCACertificate(ROOT_X500, rootCAKeyPair)
// Intermediate CA key and cert.
val intermediateCAKeyPair = Crypto.generateKeyPair(intermediateCAScheme)
val intermediateCACert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA,
rootCACert,
rootCAKeyPair,
INTERMEDIATE_X500,
intermediateCAKeyPair.public
)
// Client 1 keys, certs and SSLKeyStore.
val serverCAKeyPair = Crypto.generateKeyPair(serverCAScheme)
val serverCACert = X509Utilities.createCertificate(
CertificateType.NODE_CA,
intermediateCACert,
intermediateCAKeyPair,
CLIENT_1_X500.x500Principal,
serverCAKeyPair.public
)
val serverTLSKeyPair = Crypto.generateKeyPair(serverTLSScheme)
val serverTLSCert = X509Utilities.createCertificate(
CertificateType.TLS,
serverCACert,
serverCAKeyPair,
CLIENT_1_X500.x500Principal,
serverTLSKeyPair.public
)
val serverTLSKeyStore = loadOrCreateKeyStore(serverTLSKeyStorePath, PASSWORD)
serverTLSKeyStore.addOrReplaceKey(
X509Utilities.CORDA_CLIENT_TLS,
serverTLSKeyPair.private,
PASSWORD.toCharArray(),
arrayOf(serverTLSCert, serverCACert, intermediateCACert, rootCACert))
// serverTLSKeyStore.save(serverTLSKeyStorePath, PASSWORD)
val serverTLSKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
serverTLSKeyManagerFactory.init(serverTLSKeyStore, PASSWORD.toCharArray())
// Client 2 keys, certs and SSLKeyStore.
val clientCAKeyPair = Crypto.generateKeyPair(clientCAScheme)
val clientCACert = X509Utilities.createCertificate(
CertificateType.NODE_CA,
intermediateCACert,
intermediateCAKeyPair,
CLIENT_2_X500.x500Principal,
clientCAKeyPair.public
)
val clientTLSKeyPair = Crypto.generateKeyPair(clientTLSScheme)
val clientTLSCert = X509Utilities.createCertificate(
CertificateType.TLS,
clientCACert,
clientCAKeyPair,
CLIENT_2_X500.x500Principal,
clientTLSKeyPair.public
)
val clientTLSKeyStore = loadOrCreateKeyStore(clientTLSKeyStorePath, PASSWORD)
clientTLSKeyStore.addOrReplaceKey(
X509Utilities.CORDA_CLIENT_TLS,
clientTLSKeyPair.private,
PASSWORD.toCharArray(),
arrayOf(clientTLSCert, clientCACert, intermediateCACert, rootCACert))
// clientTLSKeyStore.save(clientTLSKeyStorePath, PASSWORD)
val clientTLSKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
clientTLSKeyManagerFactory.init(clientTLSKeyStore, PASSWORD.toCharArray())
val trustStore = loadOrCreateKeyStore(trustStorePath, PASSWORD)
trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCACert)
trustStore.addOrReplaceCertificate(X509Utilities.CORDA_INTERMEDIATE_CA, intermediateCACert)
// trustStore.save(trustStorePath, PASSWORD)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
return Pair(
SslContextBuilder
.forServer(serverTLSKeyManagerFactory)
.trustManager(trustManagerFactory)
.ciphers(cipherSuitesServer.toMutableList())
.clientAuth(ClientAuth.REQUIRE)
.protocols("TLSv1.2")
.sslProvider(if (sslSetup.serverNative) SslProvider.OPENSSL else SslProvider.JDK)
.build(),
SslContextBuilder
.forClient()
.keyManager(clientTLSKeyManagerFactory)
.trustManager(trustManagerFactory)
.ciphers(cipherSuitesClient.toMutableList())
.protocols("TLSv1.2")
.sslProvider(if (sslSetup.clientNative) SslProvider.OPENSSL else SslProvider.JDK)
.build()
)
}
private fun testConnect(serverContext: SslContext, clientContext: SslContext, expectedCipherSuite: String) {
val serverHandler = NettyTestHandler { ctx, msg -> ctx?.writeAndFlush(msg) }
val clientHandler = NettyTestHandler { _, msg -> assertEquals("Hello!", NettyTestHandler.readString(msg)) }
NettyTestServer(serverContext, serverHandler, portAllocation.nextPort()).use { server ->
server.start()
NettyTestClient(clientContext, InetAddress.getLocalHost().canonicalHostName, server.port, clientHandler).use { client ->
client.start()
clientHandler.writeString("Hello!")
val readCalled = clientHandler.waitForReadCalled()
clientHandler.rethrowIfFailed()
serverHandler.rethrowIfFailed()
assertEquals(1, serverHandler.readCalledCounter)
assertEquals(1, clientHandler.readCalledCounter)
assertTrue(readCalled)
assertEquals(expectedCipherSuite, client.engine!!.session.cipherSuite)
}
}
}
}

View File

@ -0,0 +1,99 @@
package net.corda.testing.internal
import io.netty.bootstrap.Bootstrap
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.handler.ssl.SslContext
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.ssl.SslHandler
import java.io.Closeable
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.SSLEngine
import kotlin.concurrent.thread
class NettyTestClient(
val sslContext: SslContext?,
val targetHost: String,
val targetPort: Int,
val handler: ChannelInboundHandlerAdapter
) : Closeable {
internal var mainThread: Thread? = null
internal var channelFuture: ChannelFuture? = null
// lock/condition to make sure that start only returns when the server is actually running
private val lock = ReentrantLock()
private val condition = lock.newCondition()
var engine: SSLEngine? = null
private set
fun start() {
try {
lock.lock()
mainThread = thread(start = true) { run() }
if (!condition.await(5, TimeUnit.SECONDS)) {
throw TimeoutException("Netty test server failed to start")
}
} finally {
lock.unlock()
}
}
private fun run() {
// Configure the client.
val group = NioEventLoopGroup()
try {
val b = Bootstrap()
b.group(group)
.channel(NioSocketChannel::class.java)
.option(ChannelOption.TCP_NODELAY, true)
.handler(object : ChannelInitializer<SocketChannel>() {
@Throws(Exception::class)
public override fun initChannel(ch: SocketChannel) {
val p = ch.pipeline()
if (sslContext != null) {
engine = sslContext.newEngine(ch.alloc(), targetHost, targetPort)
p.addLast(SslHandler(engine))
}
//p.addLast(new LoggingHandler(LogLevel.INFO));
p.addLast(handler)
}
})
// Start the client.
val f = b.connect(targetHost, targetPort)
try {
lock.lock()
condition.signal()
channelFuture = f.sync()
} finally {
lock.unlock()
}
// Wait until the connection is closed.
f.channel().closeFuture().sync()
} finally {
// Shut down the event loop to terminate all threads.
group.shutdownGracefully()
}
}
fun stop() {
channelFuture?.channel()?.close()
mainThread?.join()
mainThread = null
channelFuture = null
}
override fun close() {
stop()
}
}

View File

@ -0,0 +1,74 @@
package net.corda.testing.internal
import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.Channel
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandlerContext
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
class NettyTestHandler(val onMessageFunc: (ctx: ChannelHandlerContext?, msg: Any?) -> Unit = { _, _ -> }) : ChannelDuplexHandler() {
private var channel: Channel? = null
private var failure: Throwable? = null
private val lock = ReentrantLock()
private val condition = lock.newCondition()
var readCalledCounter: Int = 0
private set
override fun channelRegistered(ctx: ChannelHandlerContext?) {
channel = ctx?.channel()
super.channelRegistered(ctx)
}
override fun channelRead(ctx: ChannelHandlerContext?, msg: Any?) {
try {
lock.lock()
readCalledCounter++
onMessageFunc(ctx, msg)
} catch( e: Throwable ){
failure = e
} finally {
condition.signal()
lock.unlock()
}
}
fun writeString(msg: String) {
val buffer = Unpooled.wrappedBuffer(msg.toByteArray())
require(channel != null) { "Channel must be registered before sending messages" }
channel!!.writeAndFlush(buffer)
}
fun rethrowIfFailed() {
failure?.also { throw it }
}
fun waitForReadCalled(numberOfExpectedCalls: Int = 1): Boolean {
try {
lock.lock()
if (readCalledCounter >= numberOfExpectedCalls) {
return true
}
while (readCalledCounter < numberOfExpectedCalls) {
if (!condition.await(5, TimeUnit.SECONDS)) {
return false
}
}
return true
} finally {
lock.unlock()
}
}
companion object {
fun readString(buffer: Any?): String {
checkNotNull(buffer)
val ar = ByteArray((buffer as ByteBuf).readableBytes())
buffer.readBytes(ar)
return String(ar)
}
}
}

View File

@ -0,0 +1,97 @@
package net.corda.testing.internal
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler
import io.netty.handler.ssl.SslContext
import java.io.Closeable
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeoutException
import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.thread
class NettyTestServer(
private val sslContext: SslContext?,
val messageHandler: ChannelInboundHandlerAdapter,
val port: Int
) : Closeable {
internal var mainThread: Thread? = null
internal var channel: ChannelFuture? = null
// lock/condition to make sure that start only returns when the server is actually running
val lock = ReentrantLock()
val condition: Condition = lock.newCondition()
fun start() {
try {
lock.lock()
mainThread = thread(start = true) { run() }
if (!condition.await(5, TimeUnit.SECONDS)) {
throw TimeoutException("Netty test server failed to start")
}
} finally {
lock.unlock()
}
}
fun run() {
// Configure the server.
val bossGroup = NioEventLoopGroup(1)
val workerGroup = NioEventLoopGroup()
try {
val b = ServerBootstrap()
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel::class.java)
.option(ChannelOption.SO_BACKLOG, 100)
.handler(LoggingHandler(LogLevel.INFO))
.childHandler(object : ChannelInitializer<SocketChannel>() {
@Throws(Exception::class)
public override fun initChannel(ch: SocketChannel) {
val p = ch.pipeline()
if (sslContext != null) {
p.addLast(sslContext.newHandler(ch.alloc()))
}
//p.addLast(new LoggingHandler(LogLevel.INFO));
p.addLast(messageHandler)
}
})
// Start the server.
val f = b.bind(port)
try {
lock.lock()
channel = f.sync()
condition.signal()
} finally {
lock.unlock()
}
// Wait until the server socket is closed.
channel!!.channel().closeFuture().sync()
} finally {
// Shut down all event loops to terminate all threads.
bossGroup.shutdownGracefully()
workerGroup.shutdownGracefully()
}
}
fun stop() {
channel?.channel()?.close()
mainThread?.join()
channel = null
mainThread = null
}
override fun close() {
stop()
}
}

View File

@ -42,11 +42,11 @@ class CertificateStoreStubs {
companion object {
@JvmStatic
fun withCertificatesDirectory(certificatesDirectory: Path, keyStoreFileName: String = KeyStore.DEFAULT_STORE_FILE_NAME, keyStorePassword: String = KeyStore.DEFAULT_STORE_PASSWORD, trustStoreFileName: String = TrustStore.DEFAULT_STORE_FILE_NAME, trustStorePassword: String = TrustStore.DEFAULT_STORE_PASSWORD): MutualSslConfiguration {
fun withCertificatesDirectory(certificatesDirectory: Path, keyStoreFileName: String = KeyStore.DEFAULT_STORE_FILE_NAME, keyStorePassword: String = KeyStore.DEFAULT_STORE_PASSWORD, trustStoreFileName: String = TrustStore.DEFAULT_STORE_FILE_NAME, trustStorePassword: String = TrustStore.DEFAULT_STORE_PASSWORD, useOpenSsl: Boolean = false): MutualSslConfiguration {
val keyStore = FileBasedCertificateStoreSupplier(certificatesDirectory / keyStoreFileName, keyStorePassword)
val trustStore = FileBasedCertificateStoreSupplier(certificatesDirectory / trustStoreFileName, trustStorePassword)
return SslConfiguration.mutual(keyStore, trustStore)
return SslConfiguration.mutual(keyStore, trustStore, useOpenSsl)
}
@JvmStatic

View File

@ -0,0 +1,88 @@
package net.corda.testing.internal
import io.netty.channel.ChannelInboundHandlerAdapter
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
class TestNettyTestInfra {
@Test
fun testStartAndStopServer() {
val testHandler = rigorousMock<ChannelInboundHandlerAdapter>()
NettyTestServer(null, testHandler, 56234).use { server ->
server.start()
assertNotNull(server.mainThread)
assertNotNull(server.channel)
}
}
@Test
fun testStartAndStopClient() {
val serverHandler = ChannelInboundHandlerAdapter()
val clientHandler = ChannelInboundHandlerAdapter()
NettyTestServer(null, serverHandler, 56234).use { server ->
server.start()
NettyTestClient(null, "localhost", 56234, clientHandler).use { client ->
client.start()
assertNotNull(client.mainThread)
assertNotNull(client.channelFuture)
}
}
}
@Test
fun testPingPong() {
val serverHandler = NettyTestHandler { ctx, msg ->
ctx?.writeAndFlush(msg)
}
val clientHandler = NettyTestHandler { _, msg ->
assertEquals("ping", NettyTestHandler.readString(msg))
}
NettyTestServer(null, serverHandler, 56234).use { server ->
server.start()
NettyTestClient(null, "localhost", 56234, clientHandler).use { client ->
client.start()
clientHandler.writeString("ping")
assertTrue(clientHandler.waitForReadCalled(1))
clientHandler.rethrowIfFailed()
assertEquals(1, clientHandler.readCalledCounter)
assertEquals(1, serverHandler.readCalledCounter)
}
}
}
@Test
fun testFailureHandling() {
val serverHandler = NettyTestHandler { ctx, msg ->
ctx?.writeAndFlush(msg)
}
val clientHandler = NettyTestHandler { _, msg ->
assertEquals("pong", NettyTestHandler.readString(msg))
}
NettyTestServer(null, serverHandler, 56234).use { server ->
server.start()
NettyTestClient(null, "localhost", 56234, clientHandler).use { client ->
client.start()
clientHandler.writeString("ping")
assertTrue(clientHandler.waitForReadCalled(1))
var exceptionThrown = false
try {
clientHandler.rethrowIfFailed()
} catch (e: AssertionError) {
exceptionThrown = true
}
assertTrue(exceptionThrown, "Expected assertion failure has not been thrown")
assertEquals(1, serverHandler.readCalledCounter)
assertEquals(1, clientHandler.readCalledCounter)
}
}
}
}