Do not black-list AMQP targets that suffer a handshake failure

This commit is contained in:
Chris Cochrane 2022-09-13 11:41:19 +01:00
parent 242d7d45c5
commit 5ca5b8d096
No known key found for this signature in database
GPG Key ID: 4D4602B5BBC63950
6 changed files with 183 additions and 26 deletions

View File

@ -58,7 +58,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
private var remoteCert: X509Certificate? = null
private var eventProcessor: EventProcessor? = null
private var suppressClose: Boolean = false
private var badCert: Boolean = false
private var connectionResult: ConnectionResult = ConnectionResult.NO_ERROR
private var localCert: X509Certificate? = null
private var requestedServerName: String? = null
@ -131,7 +131,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
val ch = ctx.channel()
logInfoWithMDC { "Closed client connection ${ch.id()} from $remoteAddress to ${ch.localAddress()}" }
if (!suppressClose) {
onClose(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false, badCert))
onClose(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false, connectionResult))
}
eventProcessor?.close()
ctx.fireChannelInactive()
@ -274,13 +274,13 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
val remoteX500Name = try {
CordaX500Name.build(remoteCert!!.subjectX500Principal)
} catch (ex: IllegalArgumentException) {
badCert = true
connectionResult = ConnectionResult.HANDSHAKE_FAILURE
logErrorWithMDC("Certificate subject not a valid CordaX500Name", ex)
ctx.close()
return
}
if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) {
badCert = true
connectionResult = ConnectionResult.HANDSHAKE_FAILURE
logErrorWithMDC("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames")
ctx.close()
return
@ -288,7 +288,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
logInfoWithMDC { "Handshake completed with subject: $remoteX500Name, requested server name: ${sslHandler.getRequestedServerName()}." }
createAMQPEngine(ctx)
onOpen(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, connected = true, badCert = false))
onOpen(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, connected = true, connectionResult = ConnectionResult.NO_ERROR))
}
private fun handleFailedHandshake(ctx: ChannelHandlerContext, evt: SslHandshakeCompletionEvent) {
@ -303,7 +303,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
// io.netty.handler.ssl.SslHandler.setHandshakeFailureTransportFailure()
cause is SSLException && (cause.message?.contains("writing TLS control frames") == true) -> logWarnWithMDC(cause.message!!)
cause is SSLException && (cause.message?.contains("internal_error") == true) -> logWarnWithMDC("Received internal_error during handshake")
else -> badCert = true
else -> connectionResult = ConnectionResult.HANDSHAKE_FAILURE
}
logWarnWithMDC("Handshake failure: ${evt.cause().message}")
if (log.isTraceEnabled) {

View File

@ -26,6 +26,7 @@ import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
@ -70,6 +71,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
private const val MAX_RETRY_INTERVAL = 60000L
private const val BACKOFF_MULTIPLIER = 2L
private val NUM_CLIENT_THREADS = Integer.getInteger(CORDA_AMQP_NUM_CLIENT_THREAD_PROP_NAME, 2)
private val handshakeRetryIntervals = List(5) { Duration.ofMinutes(5) }
}
private val lock = ReentrantLock()
@ -82,7 +84,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private val badCertTargets = mutableSetOf<NetworkHostAndPort>()
private val handshakeFailureRetryTargets = mutableSetOf<NetworkHostAndPort>()
private var retryingHandshakeFailures = false
private var retryOffset = 0
@Volatile
private var amqpActive = false
@Volatile
@ -91,22 +95,67 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
val localAddressString: String
get() = clientChannel?.localAddress()?.toString() ?: "<unknownLocalAddress>"
private fun nextTarget() {
/*
Figure out the index of the next address to try to connect to
*/
private fun setTargetIndex() {
val origIndex = targetIndex
targetIndex = -1
for (offset in 1..targets.size) {
val newTargetIndex = (origIndex + offset).rem(targets.size)
if (targets[newTargetIndex] !in badCertTargets) {
if (targets[newTargetIndex] !in handshakeFailureRetryTargets ) {
targetIndex = newTargetIndex
break
}
}
if (targetIndex == -1) {
log.error("No targets have presented acceptable certificates for $allowedRemoteLegalNames. Halting retries")
return
}
/*
Set how long to wait until trying to connect to the next address
*/
private fun setTargetRetryInterval() {
retryInterval = if (retryingHandshakeFailures) {
if (retryOffset < handshakeRetryIntervals.size) {
handshakeRetryIntervals[retryOffset++].toMillis()
} else {
Duration.ofDays(1).toMillis()
}
} else {
min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER)
}
log.info("Retry connect to ${targets[targetIndex]}")
retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER)
}
/**
 * Called once a connection has been established.
 *
 * Logs the target that was reached and puts all retry bookkeeping back to its initial state
 * ([retryOffset], [retryInterval], [retryingHandshakeFailures]) so that a later connection failure
 * starts again from the shortest retry interval.
 */
private fun successfullyConnected() {
    log.info("Successfully connected to [${targets[targetIndex]}]; resetting the target connection-retry interval")
    retryOffset = 0
    retryInterval = MIN_RETRY_INTERVAL
    retryingHandshakeFailures = false
}
/**
 * Chooses the next target to connect to and the delay before connecting to it.
 *
 * [setTargetIndex] leaves [targetIndex] at -1 when it cannot pick a target. In that case:
 *  - if [handshakeFailureRetryTargets] is empty, an error is logged and the method returns without
 *    updating the retry interval;
 *  - otherwise the set is cleared, [retryingHandshakeFailures] is switched on and the index is
 *    selected again so that targets which previously failed to handshake are retried.
 */
private fun nextTarget() {
    setTargetIndex()
    val noTargetChosen = targetIndex == -1
    if (noTargetChosen && handshakeFailureRetryTargets.isEmpty()) {
        log.error("Attempted connection to targets: $targets, but none of them have presented acceptable certificates" +
                " for $allowedRemoteLegalNames. Halting retries.")
        return
    }
    if (noTargetChosen) {
        // Every remaining candidate was parked after a handshake failure: release them and go round again.
        log.info("Failed to connect to any targets. Retrying targets that previously failed to handshake.")
        handshakeFailureRetryTargets.clear()
        retryingHandshakeFailures = true
        setTargetIndex()
    }
    setTargetRetryInterval()
    log.info("Retry connect to ${targets[targetIndex]} in [$retryInterval] ms")
}
private val connectListener = object : ChannelFutureListener {
@ -212,7 +261,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
onOpen = { _, change ->
parent.run {
amqpActive = true
retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly
successfullyConnected()
_onConnection.onNext(change)
}
},
@ -220,9 +269,9 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
if (parent.amqpChannelHandler == amqpChannelHandler) {
parent.run {
_onConnection.onNext(change)
if (change.badCert) {
log.error("Blocking future connection attempts to $target due to bad certificate on endpoint")
badCertTargets += target
if (change.connectionResult == ConnectionResult.HANDSHAKE_FAILURE) {
log.warn("Handshake failure with $target target; will retry later")
handshakeFailureRetryTargets += target
}
if (started && amqpActive) {

View File

@ -3,8 +3,8 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import java.net.InetSocketAddress
import java.security.cert.X509Certificate
data class ConnectionChange(val remoteAddress: InetSocketAddress, val remoteCert: X509Certificate?, val connected: Boolean, val badCert: Boolean) {
data class ConnectionChange(val remoteAddress: InetSocketAddress, val remoteCert: X509Certificate?, val connected: Boolean, val connectionResult: ConnectionResult) {
override fun toString(): String {
return "ConnectionChange remoteAddress: $remoteAddress connected state: $connected cert subject: ${remoteCert?.subjectDN} cert ok: ${!badCert}"
return "ConnectionChange remoteAddress: $remoteAddress connected state: $connected cert subject: ${remoteCert?.subjectDN} result: ${connectionResult}"
}
}
}

View File

@ -0,0 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty
/**
 * Result recorded for an AMQP connection and reported through `ConnectionChange.connectionResult`.
 */
enum class ConnectionResult {
/** No handshake failure was recorded; this is also the value reported when a connection opens successfully. */
NO_ERROR,
/**
 * The handshake was rejected: the remote certificate subject was not a valid CordaX500Name, was not in the
 * allowed set of legal names, or the SSL handshake failed for a reason the channel handler does not treat
 * as transient. AMQPClient parks such targets and retries them later.
 */
HANDSHAKE_FAILURE
}

View File

@ -14,6 +14,7 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionResult
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
@ -29,6 +30,7 @@ import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
@ -211,7 +213,7 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
val clientConnect = clientConnected.get()
assertFalse(clientConnect.connected)
// Not a handshake failure (e.g. a bad certificate), but a timeout during handshake
assertFalse(clientConnect.badCert)
assertEquals(ConnectionResult.NO_ERROR, clientConnect.connectionResult)
}
}
assertFalse(serverThread.isActive)

View File

@ -36,6 +36,7 @@ import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Assert.assertArrayEquals
import org.junit.Rule
@ -207,6 +208,103 @@ class ProtonWrapperTests {
assertTrue(done)
}
@Suppress("TooGenericExceptionCaught") // Too generic exception thrown!
@Test(timeout=300_000)
fun `AMPQClient that fails to handshake with a server will retry the server`() {
    /*
        This test has been modelled on `Test AMQP Client with invalid root certificate`, above.
        The aim is to set up a server with an invalid root cert so that the TLS handshake will fail.
        The test allows the AMQPClient to retry the connection (which it should do) and passes once the
        server side has observed at least two failed handshakes, i.e. the first attempt plus one retry.
    */
    val certificatesDirectory = temporaryFolder.root.toPath()
    val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory, "serverstorepass")
    val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, keyStorePassword = "serverstorepass")

    val (rootCa, intermediateCa) = createDevIntermediateCaCertPath()

    // Generate server cert and private key and populate another keystore suitable for SSL
    signingCertificateStore.get(true).also { it.installDevNodeCaCertPath(ALICE_NAME, rootCa.certificate, intermediateCa) }
    sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) }
    sslConfig.createTrustStore(rootCa.certificate)

    val keyStore = sslConfig.keyStore.get()
    val trustStore = sslConfig.trustStore.get()

    val context = SSLContext.getInstance("TLS")
    val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
    keyManagerFactory.init(keyStore)
    val keyManagers = keyManagerFactory.keyManagers
    val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
    trustMgrFactory.init(trustStore)
    val trustManagers = trustMgrFactory.trustManagers
    context.init(keyManagers, trustManagers, newSecureRandom())

    val serverSocketFactory = context.serverSocketFactory

    val serverSocket = serverSocketFactory.createServerSocket(serverPort) as SSLServerSocket
    val serverParams = SSLParameters(ArtemisTcpTransport.CIPHER_SUITES.toTypedArray(),
            arrayOf("TLSv1.2"))
    serverParams.wantClientAuth = true
    serverParams.needClientAuth = true
    serverParams.endpointIdentificationAlgorithm = null // Reconfirm default no server name indication, use our own validator.
    serverSocket.sslParameters = serverParams
    serverSocket.useClientMode = false

    // Both values are written by the server thread and read by the test thread (and vice versa), so they
    // must be atomics: plain captured locals give no cross-thread visibility guarantee.
    val done = java.util.concurrent.atomic.AtomicBoolean(false)
    val handshakeErrorCount = java.util.concurrent.atomic.AtomicInteger(0)

    //
    // This is the thread that acts as the server-side endpoint for the AMQPClient to connect to.
    //
    val serverThread = thread {
        //
        // The server thread will keep making itself available for SSL connections until
        // the 'done' flag is set by the client thread, later on.
        //
        while (!done.get()) {
            try {
                // 'use' guarantees the accepted socket is released whether or not the handshake succeeds.
                (serverSocket.accept() as SSLSocket).use { acceptedSocket ->
                    acceptedSocket.addHandshakeCompletedListener {
                        done.set(true)
                    }
                    acceptedSocket.startHandshake()
                }
            } catch (ex: SSLException) {
                handshakeErrorCount.incrementAndGet()
            } catch (e: Throwable) {
                println(e)
            }
        }
    }

    try {
        //
        // Create the AMQPClient but only specify one server endpoint to connect to.
        //
        val amqpClient = createClient(serverAddressList = listOf(NetworkHostAndPort("localhost", serverPort)))
        amqpClient.use {
            amqpClient.start()
            //
            // Waiting for the number of handshake errors to get to at least 2.
            // This happens when the AMQPClient has made its first retry attempt, which is
            // what this test is interested in.
            //
            while (handshakeErrorCount.get() < 2) {
                Thread.sleep(2)
            }
            done.set(true)
        }
    } finally {
        // Always stop the server side, even if the wait above was interrupted by the test timeout.
        // Closing the listening socket first unblocks a pending accept() so the join does not have to time out.
        done.set(true)
        serverSocket.close()
        serverThread.join(1000)
    }

    //
    // check that there was at least one retry i.e. > 1 handshake error.
    //
    Assertions.assertThat(handshakeErrorCount.get()).isGreaterThan(1)
    assertTrue(done.get())
}
@Test(timeout=300_000)
fun `Client Failover for multiple IP`() {
@ -450,7 +548,11 @@ class ProtonWrapperTests {
return Pair(server, client)
}
private fun createClient(maxMessageSize: Int = MAX_MESSAGE_SIZE): AMQPClient {
private fun createClient(maxMessageSize: Int = MAX_MESSAGE_SIZE,
serverAddressList: List<NetworkHostAndPort> = listOf(
NetworkHostAndPort("localhost", serverPort),
NetworkHostAndPort("localhost", serverPort2),
NetworkHostAndPort("localhost", artemisPort))): AMQPClient {
val baseDirectory = temporaryFolder.root.toPath() / "client"
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
@ -474,9 +576,7 @@ class ProtonWrapperTests {
override val maxMessageSize: Int = maxMessageSize
}
return AMQPClient(
listOf(NetworkHostAndPort("localhost", serverPort),
NetworkHostAndPort("localhost", serverPort2),
NetworkHostAndPort("localhost", artemisPort)),
serverAddressList,
setOf(ALICE_NAME, CHARLIE_NAME),
amqpConfig)
}