diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt index ae83f53455..4663a038a9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt @@ -83,24 +83,40 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, val sslHandler = ctx.pipeline().get(SslHandler::class.java) localCert = sslHandler.engine().session.localCertificates[0].x509 remoteCert = sslHandler.engine().session.peerCertificates[0].x509 - try { - val remoteX500Name = CordaX500Name.build(remoteCert!!.subjectX500Principal) - require(allowedRemoteLegalNames == null || remoteX500Name in allowedRemoteLegalNames) - log.info("handshake completed subject: $remoteX500Name") + val remoteX500Name = try { + CordaX500Name.build(remoteCert!!.subjectX500Principal) } catch (ex: IllegalArgumentException) { - log.error("Invalid certificate subject", ex) + log.error("Certificate subject not a valid CordaX500Name", ex) ctx.close() return } + if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) { + log.error("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames") + ctx.close() + return + } + log.info("Handshake completed with subject: $remoteX500Name") createAMQPEngine(ctx) onOpen(Pair(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, true))) } else { - log.error("Handshake failure $evt") + log.error("Handshake failure ${evt.cause().message}") + if (log.isTraceEnabled) { + log.trace("Handshake failure", evt.cause()) + } ctx.close() } } } + @Suppress("OverridingDeprecatedMember") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + log.warn("Closing channel due to nonrecoverable exception ${cause.message}") + if (log.isTraceEnabled) { + log.trace("Pipeline uncaught exception", cause) + } + ctx.close() + } + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { try { if (msg is ByteBuf) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 3c7bb32db2..7313c1e139 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -17,6 +17,7 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import rx.Observable import rx.subjects.PublishSubject +import java.lang.Long.min import java.security.KeyStore import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock @@ -47,7 +48,9 @@ class AMQPClient(val targets: List, } val log = contextLogger() - const val RETRY_INTERVAL = 1000L + const val MIN_RETRY_INTERVAL = 1000L + const val MAX_RETRY_INTERVAL = 60000L + const val BACKOFF_MULTIPLIER = 2L const val NUM_CLIENT_THREADS = 2 } @@ -60,6 +63,13 @@ class AMQPClient(val targets: List, // Offset into the list of targets, so that we can implement round-robin reconnect logic. private var targetIndex = 0 private var currentTarget: NetworkHostAndPort = targets.first() + private var retryInterval = MIN_RETRY_INTERVAL + + private fun nextTarget() { + targetIndex = (targetIndex + 1).rem(targets.size) + log.info("Retry connect to ${targets[targetIndex]}") + retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER) + } private val connectListener = object : ChannelFutureListener { override fun operationComplete(future: ChannelFuture) { @@ -68,10 +78,9 @@ class AMQPClient(val targets: List, if (!stopping) { workerGroup?.schedule({ - log.info("Retry connect to $currentTarget") - targetIndex = (targetIndex + 1).rem(targets.size) + nextTarget() restart() - }, RETRY_INTERVAL, TimeUnit.MILLISECONDS) + }, retryInterval, TimeUnit.MILLISECONDS) } } else { log.info("Connected to $currentTarget") @@ -89,10 +98,9 @@ class AMQPClient(val targets: List, clientChannel = null if (!stopping) { workerGroup?.schedule({ - log.info("Retry connect") - targetIndex = (targetIndex + 1).rem(targets.size) + nextTarget() restart() - }, RETRY_INTERVAL, TimeUnit.MILLISECONDS) + }, retryInterval, TimeUnit.MILLISECONDS) } } } @@ -116,7 +124,10 @@ class AMQPClient(val targets: List, parent.userName, parent.password, parent.trace, - { parent._onConnection.onNext(it.second) }, + { + parent.retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly + parent._onConnection.onNext(it.second) + }, { parent._onConnection.onNext(it.second) }, { rcv -> parent._onReceive.onNext(rcv) })) } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt index 689211fc30..0f799a8164 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt @@ -2,22 +2,111 @@ package net.corda.nodeapi.internal.protonwrapper.netty import io.netty.handler.ssl.SslHandler import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.toHex import net.corda.nodeapi.ArtemisTcpTransport +import net.corda.nodeapi.internal.crypto.toBc +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier +import java.net.Socket import java.security.KeyStore import java.security.SecureRandom -import java.security.cert.CertPathBuilder -import java.security.cert.PKIXBuilderParameters -import java.security.cert.PKIXRevocationChecker -import java.security.cert.X509CertSelector +import java.security.cert.* import java.util.* import javax.net.ssl.* +internal class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() { + companion object { + val log = contextLogger() + } + + private fun certPathToString(certPath: Array?): String { + if (certPath == null) { + return "" + } + val certs = certPath.map { + val bcCert = it.toBc() + val subject = bcCert.subject.toString() + val issuer = bcCert.issuer.toString() + val keyIdentifier = try { + SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex() + } catch (ex: Exception) { + "null" + } + val authorityKeyIdentifier = try { + AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex() + } catch (ex: Exception) { + "null" + } + " $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier]" + } + return certs.joinToString("\r\n") + } + + + private fun certPathToStringFull(chain: Array?): String { + if (chain == null) { + return "" + } + return chain.map { it.toString() }.joinToString(", ") + } + + private fun logErrors(chain: Array?, block: () -> Unit) { + try { + block() + } catch (ex: CertificateException) { + log.error("Bad certificate path ${ex.message}:\r\n${certPathToStringFull(chain)}") + throw ex + } + } + + @Throws(CertificateException::class) + override fun checkClientTrusted(chain: Array?, authType: String?, socket: Socket?) { + log.info("Check Client Certpath:\r\n${certPathToString(chain)}") + logErrors(chain) { wrapped.checkClientTrusted(chain, authType, socket) } + } + + @Throws(CertificateException::class) + override fun checkClientTrusted(chain: Array?, authType: String?, engine: SSLEngine?) { + log.info("Check Client Certpath:\r\n${certPathToString(chain)}") + logErrors(chain) { wrapped.checkClientTrusted(chain, authType, engine) } + } + + @Throws(CertificateException::class) + override fun checkClientTrusted(chain: Array?, authType: String?) { + log.info("Check Client Certpath:\r\n${certPathToString(chain)}") + logErrors(chain) { wrapped.checkClientTrusted(chain, authType) } + } + + @Throws(CertificateException::class) + override fun checkServerTrusted(chain: Array?, authType: String?, socket: Socket?) { + log.info("Check Server Certpath:\r\n${certPathToString(chain)}") + logErrors(chain) { wrapped.checkServerTrusted(chain, authType, socket) } + } + + @Throws(CertificateException::class) + override fun checkServerTrusted(chain: Array?, authType: String?, engine: SSLEngine?) { + log.info("Check Server Certpath:\r\n${certPathToString(chain)}") + logErrors(chain) { wrapped.checkServerTrusted(chain, authType, engine) } + } + + @Throws(CertificateException::class) + override fun checkServerTrusted(chain: Array?, authType: String?) { + log.info("Check Server Certpath:\r\n${certPathToString(chain)}") + logErrors(chain) { wrapped.checkServerTrusted(chain, authType) } + } + + override fun getAcceptedIssuers(): Array = wrapped.acceptedIssuers + +} + internal fun createClientSslHelper(target: NetworkHostAndPort, keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslHandler { val sslContext = SSLContext.getInstance("TLS") val keyManagers = keyManagerFactory.keyManagers - val trustManagers = trustManagerFactory.trustManagers + val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray() sslContext.init(keyManagers, trustManagers, SecureRandom()) val sslEngine = sslContext.createSSLEngine(target.host, target.port) sslEngine.useClientMode = true @@ -31,7 +120,7 @@ internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslHandler { val sslContext = SSLContext.getInstance("TLS") val keyManagers = keyManagerFactory.keyManagers - val trustManagers = trustManagerFactory.trustManagers + val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray() sslContext.init(keyManagers, trustManagers, SecureRandom()) val sslEngine = sslContext.createSSLEngine() sslEngine.useClientMode = false diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 1090aef766..40142f1dcf 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -12,20 +12,30 @@ import net.corda.node.services.config.CertChainPolicyConfig import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.messaging.ArtemisMessagingServer +import net.corda.nodeapi.ArtemisTcpTransport.Companion.CIPHER_SUITES import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER +import net.corda.nodeapi.internal.config.SSLConfiguration +import net.corda.nodeapi.internal.createDevKeyStores +import net.corda.nodeapi.internal.crypto.* import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.testing.core.* +import net.corda.testing.internal.createDevIntermediateCaCertPath import net.corda.testing.internal.rigorousMock import org.apache.activemq.artemis.api.core.RoutingType import org.junit.Assert.assertArrayEquals import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder +import java.security.SecureRandom +import java.security.cert.X509Certificate +import javax.net.ssl.* +import kotlin.concurrent.thread import kotlin.test.assertEquals +import kotlin.test.assertTrue class ProtonWrapperTests { @Rule @@ -86,6 +96,91 @@ class ProtonWrapperTests { } } + private fun SSLConfiguration.createTrustStore(rootCert: X509Certificate) { + val trustStore = loadOrCreateKeyStore(trustStoreFile, trustStorePassword) + trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCert) + trustStore.save(trustStoreFile, trustStorePassword) + } + + + @Test + fun `Test AMQP Client with invalid root certificate`() { + val sslConfig = object : SSLConfiguration { + override val certificatesDirectory = temporaryFolder.root.toPath() + override val keyStorePassword = "serverstorepass" + override val trustStorePassword = "trustpass" + override val crlCheckSoftFail: Boolean = true + } + + val (rootCa, intermediateCa) = createDevIntermediateCaCertPath() + + // Generate server cert and private key and populate another keystore suitable for SSL + sslConfig.createDevKeyStores(ALICE_NAME, rootCa.certificate, intermediateCa) + sslConfig.createTrustStore(rootCa.certificate) + + val keyStore = loadKeyStore(sslConfig.sslKeystore, sslConfig.keyStorePassword) + val trustStore = loadKeyStore(sslConfig.trustStoreFile, sslConfig.trustStorePassword) + + val context = SSLContext.getInstance("TLS") + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + keyManagerFactory.init(keyStore, sslConfig.keyStorePassword.toCharArray()) + val keyManagers = keyManagerFactory.keyManagers + val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustMgrFactory.init(trustStore) + val trustManagers = trustMgrFactory.trustManagers + context.init(keyManagers, trustManagers, SecureRandom()) + + val serverSocketFactory = context.serverSocketFactory + + val serverSocket = serverSocketFactory.createServerSocket(serverPort) as SSLServerSocket + val serverParams = SSLParameters(CIPHER_SUITES.toTypedArray(), + arrayOf("TLSv1.2")) + serverParams.wantClientAuth = true + serverParams.needClientAuth = true + serverParams.endpointIdentificationAlgorithm = null // Reconfirm default no server name indication, use our own validator. + serverSocket.sslParameters = serverParams + serverSocket.useClientMode = false + + val lock = Object() + var done = false + var handshakeError = false + + val serverThread = thread { + try { + val sslServerSocket = serverSocket.accept() as SSLSocket + sslServerSocket.addHandshakeCompletedListener { + done = true + } + sslServerSocket.startHandshake() + synchronized(lock) { + while (!done) { + lock.wait(1000) + } + } + sslServerSocket.close() + } catch (ex: SSLHandshakeException) { + handshakeError = true + } + } + + val amqpClient = createClient() + amqpClient.use { + val clientConnected = amqpClient.onConnection.toFuture() + amqpClient.start() + val clientConnect = clientConnected.get() + assertEquals(false, clientConnect.connected) + synchronized(lock) { + done = true + lock.notifyAll() + } + } + serverThread.join(1000) + assertTrue(handshakeError) + serverSocket.close() + assertTrue(done) + } + + @Test fun `Client Failover for multiple IP`() { val amqpServer = createServer(serverPort)