ENT-1850: Improve reporting of connection problems (#3124)

* Add nicer logging for SSL handshake problems

* Just in case let people see the horrid netty exception traces at trace level
This commit is contained in:
Matthew Nesbit 2018-05-14 09:14:09 +01:00 committed by GitHub
parent e47a84ab49
commit 3c005789c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 231 additions and 20 deletions

View File

@ -83,24 +83,40 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
val sslHandler = ctx.pipeline().get(SslHandler::class.java) val sslHandler = ctx.pipeline().get(SslHandler::class.java)
localCert = sslHandler.engine().session.localCertificates[0].x509 localCert = sslHandler.engine().session.localCertificates[0].x509
remoteCert = sslHandler.engine().session.peerCertificates[0].x509 remoteCert = sslHandler.engine().session.peerCertificates[0].x509
try { val remoteX500Name = try {
val remoteX500Name = CordaX500Name.build(remoteCert!!.subjectX500Principal) CordaX500Name.build(remoteCert!!.subjectX500Principal)
require(allowedRemoteLegalNames == null || remoteX500Name in allowedRemoteLegalNames)
log.info("handshake completed subject: $remoteX500Name")
} catch (ex: IllegalArgumentException) { } catch (ex: IllegalArgumentException) {
log.error("Invalid certificate subject", ex) log.error("Certificate subject not a valid CordaX500Name", ex)
ctx.close() ctx.close()
return return
} }
if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) {
log.error("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames")
ctx.close()
return
}
log.info("Handshake completed with subject: $remoteX500Name")
createAMQPEngine(ctx) createAMQPEngine(ctx)
onOpen(Pair(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, true))) onOpen(Pair(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, true)))
} else { } else {
log.error("Handshake failure $evt") log.error("Handshake failure ${evt.cause().message}")
if (log.isTraceEnabled) {
log.trace("Handshake failure", evt.cause())
}
ctx.close() ctx.close()
} }
} }
} }
@Suppress("OverridingDeprecatedMember")
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
log.warn("Closing channel due to nonrecoverable exception ${cause.message}")
if (log.isTraceEnabled) {
log.trace("Pipeline uncaught exception", cause)
}
ctx.close()
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
try { try {
if (msg is ByteBuf) { if (msg is ByteBuf) {

View File

@ -17,6 +17,7 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.lang.Long.min
import java.security.KeyStore import java.security.KeyStore
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
@ -47,7 +48,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
} }
val log = contextLogger() val log = contextLogger()
const val RETRY_INTERVAL = 1000L const val MIN_RETRY_INTERVAL = 1000L
const val MAX_RETRY_INTERVAL = 60000L
const val BACKOFF_MULTIPLIER = 2L
const val NUM_CLIENT_THREADS = 2 const val NUM_CLIENT_THREADS = 2
} }
@ -60,6 +63,13 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
// Offset into the list of targets, so that we can implement round-robin reconnect logic. // Offset into the list of targets, so that we can implement round-robin reconnect logic.
private var targetIndex = 0 private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first() private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private fun nextTarget() {
targetIndex = (targetIndex + 1).rem(targets.size)
log.info("Retry connect to ${targets[targetIndex]}")
retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER)
}
private val connectListener = object : ChannelFutureListener { private val connectListener = object : ChannelFutureListener {
override fun operationComplete(future: ChannelFuture) { override fun operationComplete(future: ChannelFuture) {
@ -68,10 +78,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
if (!stopping) { if (!stopping) {
workerGroup?.schedule({ workerGroup?.schedule({
log.info("Retry connect to $currentTarget") nextTarget()
targetIndex = (targetIndex + 1).rem(targets.size)
restart() restart()
}, RETRY_INTERVAL, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
} }
} else { } else {
log.info("Connected to $currentTarget") log.info("Connected to $currentTarget")
@ -89,10 +98,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
clientChannel = null clientChannel = null
if (!stopping) { if (!stopping) {
workerGroup?.schedule({ workerGroup?.schedule({
log.info("Retry connect") nextTarget()
targetIndex = (targetIndex + 1).rem(targets.size)
restart() restart()
}, RETRY_INTERVAL, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
} }
} }
} }
@ -116,7 +124,10 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
parent.userName, parent.userName,
parent.password, parent.password,
parent.trace, parent.trace,
{ parent._onConnection.onNext(it.second) }, {
parent.retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly
parent._onConnection.onNext(it.second)
},
{ parent._onConnection.onNext(it.second) }, { parent._onConnection.onNext(it.second) },
{ rcv -> parent._onReceive.onNext(rcv) })) { rcv -> parent._onReceive.onNext(rcv) }))
} }

View File

@ -2,22 +2,111 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.handler.ssl.SslHandler import io.netty.handler.ssl.SslHandler
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toHex
import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.internal.crypto.toBc
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import java.net.Socket
import java.security.KeyStore import java.security.KeyStore
import java.security.SecureRandom import java.security.SecureRandom
import java.security.cert.CertPathBuilder import java.security.cert.*
import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector
import java.util.* import java.util.*
import javax.net.ssl.* import javax.net.ssl.*
internal class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object {
val log = contextLogger()
}
private fun certPathToString(certPath: Array<out X509Certificate>?): String {
if (certPath == null) {
return "<empty certpath>"
}
val certs = certPath.map {
val bcCert = it.toBc()
val subject = bcCert.subject.toString()
val issuer = bcCert.issuer.toString()
val keyIdentifier = try {
SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (ex: Exception) {
"null"
}
val authorityKeyIdentifier = try {
AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (ex: Exception) {
"null"
}
" $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier]"
}
return certs.joinToString("\r\n")
}
private fun certPathToStringFull(chain: Array<out X509Certificate>?): String {
if (chain == null) {
return "<empty certpath>"
}
return chain.map { it.toString() }.joinToString(", ")
}
private fun logErrors(chain: Array<out X509Certificate>?, block: () -> Unit) {
try {
block()
} catch (ex: CertificateException) {
log.error("Bad certificate path ${ex.message}:\r\n${certPathToStringFull(chain)}")
throw ex
}
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?, socket: Socket?) {
log.info("Check Client Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkClientTrusted(chain, authType, socket) }
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?, engine: SSLEngine?) {
log.info("Check Client Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkClientTrusted(chain, authType, engine) }
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
log.info("Check Client Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkClientTrusted(chain, authType) }
}
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?, socket: Socket?) {
log.info("Check Server Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkServerTrusted(chain, authType, socket) }
}
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?, engine: SSLEngine?) {
log.info("Check Server Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkServerTrusted(chain, authType, engine) }
}
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
log.info("Check Server Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkServerTrusted(chain, authType) }
}
override fun getAcceptedIssuers(): Array<X509Certificate> = wrapped.acceptedIssuers
}
internal fun createClientSslHelper(target: NetworkHostAndPort, internal fun createClientSslHelper(target: NetworkHostAndPort,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS") val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, SecureRandom()) sslContext.init(keyManagers, trustManagers, SecureRandom())
val sslEngine = sslContext.createSSLEngine(target.host, target.port) val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true sslEngine.useClientMode = true
@ -31,7 +120,7 @@ internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS") val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, SecureRandom()) sslContext.init(keyManagers, trustManagers, SecureRandom())
val sslEngine = sslContext.createSSLEngine() val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false sslEngine.useClientMode = false

View File

@ -12,20 +12,30 @@ import net.corda.node.services.config.CertChainPolicyConfig
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.ArtemisMessagingServer
import net.corda.nodeapi.ArtemisTcpTransport.Companion.CIPHER_SUITES
import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER
import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.createDevKeyStores
import net.corda.nodeapi.internal.crypto.*
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.testing.core.* import net.corda.testing.core.*
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import java.security.SecureRandom
import java.security.cert.X509Certificate
import javax.net.ssl.*
import kotlin.concurrent.thread
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue
class ProtonWrapperTests { class ProtonWrapperTests {
@Rule @Rule
@ -86,6 +96,91 @@ class ProtonWrapperTests {
} }
} }
private fun SSLConfiguration.createTrustStore(rootCert: X509Certificate) {
val trustStore = loadOrCreateKeyStore(trustStoreFile, trustStorePassword)
trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCert)
trustStore.save(trustStoreFile, trustStorePassword)
}
@Test
fun `Test AMQP Client with invalid root certificate`() {
val sslConfig = object : SSLConfiguration {
override val certificatesDirectory = temporaryFolder.root.toPath()
override val keyStorePassword = "serverstorepass"
override val trustStorePassword = "trustpass"
override val crlCheckSoftFail: Boolean = true
}
val (rootCa, intermediateCa) = createDevIntermediateCaCertPath()
// Generate server cert and private key and populate another keystore suitable for SSL
sslConfig.createDevKeyStores(ALICE_NAME, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = loadKeyStore(sslConfig.sslKeystore, sslConfig.keyStorePassword)
val trustStore = loadKeyStore(sslConfig.trustStoreFile, sslConfig.trustStorePassword)
val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore, sslConfig.keyStorePassword.toCharArray())
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, SecureRandom())
val serverSocketFactory = context.serverSocketFactory
val serverSocket = serverSocketFactory.createServerSocket(serverPort) as SSLServerSocket
val serverParams = SSLParameters(CIPHER_SUITES.toTypedArray(),
arrayOf("TLSv1.2"))
serverParams.wantClientAuth = true
serverParams.needClientAuth = true
serverParams.endpointIdentificationAlgorithm = null // Reconfirm default no server name indication, use our own validator.
serverSocket.sslParameters = serverParams
serverSocket.useClientMode = false
val lock = Object()
var done = false
var handshakeError = false
val serverThread = thread {
try {
val sslServerSocket = serverSocket.accept() as SSLSocket
sslServerSocket.addHandshakeCompletedListener {
done = true
}
sslServerSocket.startHandshake()
synchronized(lock) {
while (!done) {
lock.wait(1000)
}
}
sslServerSocket.close()
} catch (ex: SSLHandshakeException) {
handshakeError = true
}
}
val amqpClient = createClient()
amqpClient.use {
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val clientConnect = clientConnected.get()
assertEquals(false, clientConnect.connected)
synchronized(lock) {
done = true
lock.notifyAll()
}
}
serverThread.join(1000)
assertTrue(handshakeError)
serverSocket.close()
assertTrue(done)
}
@Test @Test
fun `Client Failover for multiple IP`() { fun `Client Failover for multiple IP`() {
val amqpServer = createServer(serverPort) val amqpServer = createServer(serverPort)