ENT-1850: Improve reporting of connection problems (#3124)

* Add nicer logging for SSL handshake problems

* Just in case let people see the horrid netty exception traces at trace level
This commit is contained in:
Matthew Nesbit 2018-05-14 09:14:09 +01:00 committed by GitHub
parent e47a84ab49
commit 3c005789c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 231 additions and 20 deletions

View File

@ -83,24 +83,40 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
val sslHandler = ctx.pipeline().get(SslHandler::class.java)
localCert = sslHandler.engine().session.localCertificates[0].x509
remoteCert = sslHandler.engine().session.peerCertificates[0].x509
try {
val remoteX500Name = CordaX500Name.build(remoteCert!!.subjectX500Principal)
require(allowedRemoteLegalNames == null || remoteX500Name in allowedRemoteLegalNames)
log.info("handshake completed subject: $remoteX500Name")
val remoteX500Name = try {
CordaX500Name.build(remoteCert!!.subjectX500Principal)
} catch (ex: IllegalArgumentException) {
log.error("Invalid certificate subject", ex)
log.error("Certificate subject not a valid CordaX500Name", ex)
ctx.close()
return
}
if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) {
log.error("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames")
ctx.close()
return
}
log.info("Handshake completed with subject: $remoteX500Name")
createAMQPEngine(ctx)
onOpen(Pair(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, true)))
} else {
log.error("Handshake failure $evt")
log.error("Handshake failure ${evt.cause().message}")
if (log.isTraceEnabled) {
log.trace("Handshake failure", evt.cause())
}
ctx.close()
}
}
}
@Suppress("OverridingDeprecatedMember")
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
log.warn("Closing channel due to nonrecoverable exception ${cause.message}")
if (log.isTraceEnabled) {
log.trace("Pipeline uncaught exception", cause)
}
ctx.close()
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
try {
if (msg is ByteBuf) {

View File

@ -17,6 +17,7 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.security.KeyStore
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
@ -47,7 +48,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
}
val log = contextLogger()
const val RETRY_INTERVAL = 1000L
const val MIN_RETRY_INTERVAL = 1000L
const val MAX_RETRY_INTERVAL = 60000L
const val BACKOFF_MULTIPLIER = 2L
const val NUM_CLIENT_THREADS = 2
}
@ -60,6 +63,13 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
// Offset into the list of targets, so that we can implement round-robin reconnect logic.
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private fun nextTarget() {
targetIndex = (targetIndex + 1).rem(targets.size)
log.info("Retry connect to ${targets[targetIndex]}")
retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER)
}
private val connectListener = object : ChannelFutureListener {
override fun operationComplete(future: ChannelFuture) {
@ -68,10 +78,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
if (!stopping) {
workerGroup?.schedule({
log.info("Retry connect to $currentTarget")
targetIndex = (targetIndex + 1).rem(targets.size)
nextTarget()
restart()
}, RETRY_INTERVAL, TimeUnit.MILLISECONDS)
}, retryInterval, TimeUnit.MILLISECONDS)
}
} else {
log.info("Connected to $currentTarget")
@ -89,10 +98,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
clientChannel = null
if (!stopping) {
workerGroup?.schedule({
log.info("Retry connect")
targetIndex = (targetIndex + 1).rem(targets.size)
nextTarget()
restart()
}, RETRY_INTERVAL, TimeUnit.MILLISECONDS)
}, retryInterval, TimeUnit.MILLISECONDS)
}
}
}
@ -116,7 +124,10 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
parent.userName,
parent.password,
parent.trace,
{ parent._onConnection.onNext(it.second) },
{
parent.retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly
parent._onConnection.onNext(it.second)
},
{ parent._onConnection.onNext(it.second) },
{ rcv -> parent._onReceive.onNext(rcv) }))
}

View File

@ -2,22 +2,111 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.handler.ssl.SslHandler
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.toHex
import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.internal.crypto.toBc
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import java.net.Socket
import java.security.KeyStore
import java.security.SecureRandom
import java.security.cert.CertPathBuilder
import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector
import java.security.cert.*
import java.util.*
import javax.net.ssl.*
internal class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object {
val log = contextLogger()
}
private fun certPathToString(certPath: Array<out X509Certificate>?): String {
if (certPath == null) {
return "<empty certpath>"
}
val certs = certPath.map {
val bcCert = it.toBc()
val subject = bcCert.subject.toString()
val issuer = bcCert.issuer.toString()
val keyIdentifier = try {
SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (ex: Exception) {
"null"
}
val authorityKeyIdentifier = try {
AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex()
} catch (ex: Exception) {
"null"
}
" $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier]"
}
return certs.joinToString("\r\n")
}
private fun certPathToStringFull(chain: Array<out X509Certificate>?): String {
if (chain == null) {
return "<empty certpath>"
}
return chain.map { it.toString() }.joinToString(", ")
}
private fun logErrors(chain: Array<out X509Certificate>?, block: () -> Unit) {
try {
block()
} catch (ex: CertificateException) {
log.error("Bad certificate path ${ex.message}:\r\n${certPathToStringFull(chain)}")
throw ex
}
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?, socket: Socket?) {
log.info("Check Client Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkClientTrusted(chain, authType, socket) }
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?, engine: SSLEngine?) {
log.info("Check Client Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkClientTrusted(chain, authType, engine) }
}
@Throws(CertificateException::class)
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
log.info("Check Client Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkClientTrusted(chain, authType) }
}
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?, socket: Socket?) {
log.info("Check Server Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkServerTrusted(chain, authType, socket) }
}
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?, engine: SSLEngine?) {
log.info("Check Server Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkServerTrusted(chain, authType, engine) }
}
@Throws(CertificateException::class)
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
log.info("Check Server Certpath:\r\n${certPathToString(chain)}")
logErrors(chain) { wrapped.checkServerTrusted(chain, authType) }
}
override fun getAcceptedIssuers(): Array<X509Certificate> = wrapped.acceptedIssuers
}
internal fun createClientSslHelper(target: NetworkHostAndPort,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, SecureRandom())
val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true
@ -31,7 +120,7 @@ internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, SecureRandom())
val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false

View File

@ -12,20 +12,30 @@ import net.corda.node.services.config.CertChainPolicyConfig
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer
import net.corda.nodeapi.ArtemisTcpTransport.Companion.CIPHER_SUITES
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER
import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.createDevKeyStores
import net.corda.nodeapi.internal.crypto.*
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.testing.core.*
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.testing.internal.rigorousMock
import org.apache.activemq.artemis.api.core.RoutingType
import org.junit.Assert.assertArrayEquals
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.security.SecureRandom
import java.security.cert.X509Certificate
import javax.net.ssl.*
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class ProtonWrapperTests {
@Rule
@ -86,6 +96,91 @@ class ProtonWrapperTests {
}
}
private fun SSLConfiguration.createTrustStore(rootCert: X509Certificate) {
val trustStore = loadOrCreateKeyStore(trustStoreFile, trustStorePassword)
trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCert)
trustStore.save(trustStoreFile, trustStorePassword)
}
@Test
fun `Test AMQP Client with invalid root certificate`() {
val sslConfig = object : SSLConfiguration {
override val certificatesDirectory = temporaryFolder.root.toPath()
override val keyStorePassword = "serverstorepass"
override val trustStorePassword = "trustpass"
override val crlCheckSoftFail: Boolean = true
}
val (rootCa, intermediateCa) = createDevIntermediateCaCertPath()
// Generate server cert and private key and populate another keystore suitable for SSL
sslConfig.createDevKeyStores(ALICE_NAME, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = loadKeyStore(sslConfig.sslKeystore, sslConfig.keyStorePassword)
val trustStore = loadKeyStore(sslConfig.trustStoreFile, sslConfig.trustStorePassword)
val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore, sslConfig.keyStorePassword.toCharArray())
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, SecureRandom())
val serverSocketFactory = context.serverSocketFactory
val serverSocket = serverSocketFactory.createServerSocket(serverPort) as SSLServerSocket
val serverParams = SSLParameters(CIPHER_SUITES.toTypedArray(),
arrayOf("TLSv1.2"))
serverParams.wantClientAuth = true
serverParams.needClientAuth = true
serverParams.endpointIdentificationAlgorithm = null // Reconfirm default no server name indication, use our own validator.
serverSocket.sslParameters = serverParams
serverSocket.useClientMode = false
val lock = Object()
var done = false
var handshakeError = false
val serverThread = thread {
try {
val sslServerSocket = serverSocket.accept() as SSLSocket
sslServerSocket.addHandshakeCompletedListener {
done = true
}
sslServerSocket.startHandshake()
synchronized(lock) {
while (!done) {
lock.wait(1000)
}
}
sslServerSocket.close()
} catch (ex: SSLHandshakeException) {
handshakeError = true
}
}
val amqpClient = createClient()
amqpClient.use {
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val clientConnect = clientConnected.get()
assertEquals(false, clientConnect.connected)
synchronized(lock) {
done = true
lock.notifyAll()
}
}
serverThread.join(1000)
assertTrue(handshakeError)
serverSocket.close()
assertTrue(done)
}
@Test
fun `Client Failover for multiple IP`() {
val amqpServer = createServer(serverPort)