diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt index c4a4e8ebb8..4cba1c6d95 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt @@ -58,7 +58,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, private var remoteCert: X509Certificate? = null private var eventProcessor: EventProcessor? = null private var suppressClose: Boolean = false - private var badCert: Boolean = false + private var connectionResult: ConnectionResult = ConnectionResult.NO_ERROR private var localCert: X509Certificate? = null private var requestedServerName: String? = null @@ -131,7 +131,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, val ch = ctx.channel() logInfoWithMDC { "Closed client connection ${ch.id()} from $remoteAddress to ${ch.localAddress()}" } if (!suppressClose) { - onClose(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false, badCert)) + onClose(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false, connectionResult)) } eventProcessor?.close() ctx.fireChannelInactive() @@ -274,13 +274,13 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, val remoteX500Name = try { CordaX500Name.build(remoteCert!!.subjectX500Principal) } catch (ex: IllegalArgumentException) { - badCert = true + connectionResult = ConnectionResult.HANDSHAKE_FAILURE logErrorWithMDC("Certificate subject not a valid CordaX500Name", ex) ctx.close() return } if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) { - badCert = true + connectionResult = ConnectionResult.HANDSHAKE_FAILURE logErrorWithMDC("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames") ctx.close() return @@ -288,7 +288,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, logInfoWithMDC { "Handshake completed with subject: $remoteX500Name, requested server name: ${sslHandler.getRequestedServerName()}." } createAMQPEngine(ctx) - onOpen(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, connected = true, badCert = false)) + onOpen(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, connected = true, connectionResult = ConnectionResult.NO_ERROR)) } private fun handleFailedHandshake(ctx: ChannelHandlerContext, evt: SslHandshakeCompletionEvent) { @@ -303,7 +303,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, // io.netty.handler.ssl.SslHandler.setHandshakeFailureTransportFailure() cause is SSLException && (cause.message?.contains("writing TLS control frames") == true) -> logWarnWithMDC(cause.message!!) cause is SSLException && (cause.message?.contains("internal_error") == true) -> logWarnWithMDC("Received internal_error during handshake") - else -> badCert = true + else -> connectionResult = ConnectionResult.HANDSHAKE_FAILURE } logWarnWithMDC("Handshake failure: ${evt.cause().message}") if (log.isTraceEnabled) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 4551608054..4c8b78d57f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -26,6 +26,7 @@ import rx.Observable import rx.subjects.PublishSubject import java.lang.Long.min import java.net.InetSocketAddress +import java.time.Duration import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import javax.net.ssl.KeyManagerFactory @@ -70,6 +71,7 @@ class AMQPClient(val targets: List, private const val MAX_RETRY_INTERVAL = 60000L private const val BACKOFF_MULTIPLIER = 2L private val NUM_CLIENT_THREADS = Integer.getInteger(CORDA_AMQP_NUM_CLIENT_THREAD_PROP_NAME, 2) + private val handshakeRetryIntervals = List(5) { Duration.ofMinutes(5) } } private val lock = ReentrantLock() @@ -82,7 +84,9 @@ class AMQPClient(val targets: List, private var targetIndex = 0 private var currentTarget: NetworkHostAndPort = targets.first() private var retryInterval = MIN_RETRY_INTERVAL - private val badCertTargets = mutableSetOf() + private val handshakeFailureRetryTargets = mutableSetOf() + private var retryingHandshakeFailures = false + private var retryOffset = 0 @Volatile private var amqpActive = false @Volatile @@ -91,22 +95,67 @@ class AMQPClient(val targets: List, val localAddressString: String get() = clientChannel?.localAddress()?.toString() ?: "" - private fun nextTarget() { + /* + Figure out the index of the next address to try to connect to + */ + private fun setTargetIndex() { val origIndex = targetIndex targetIndex = -1 for (offset in 1..targets.size) { val newTargetIndex = (origIndex + offset).rem(targets.size) - if (targets[newTargetIndex] !in badCertTargets) { + if (targets[newTargetIndex] !in handshakeFailureRetryTargets ) { targetIndex = newTargetIndex break } } - if (targetIndex == -1) { - log.error("No targets have presented acceptable certificates for $allowedRemoteLegalNames. Halting retries") - return + } + + /* + Set how long to wait until trying to connect to the next address + */ + private fun setTargetRetryInterval() { + retryInterval = if (retryingHandshakeFailures) { + if (retryOffset < handshakeRetryIntervals.size) { + handshakeRetryIntervals[retryOffset++].toMillis() + } else { + Duration.ofDays(1).toMillis() + } + } else { + min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER) } - log.info("Retry connect to ${targets[targetIndex]}") - retryInterval = min(MAX_RETRY_INTERVAL, retryInterval * BACKOFF_MULTIPLIER) + } + + /* + Once a connection is made, reset all the retry-connection info so if there is another connection failure + then this node tries to reconnect quickly. + */ + private fun successfullyConnected() { + log.info("Successfully connected to [${targets[targetIndex]}]; resetting the target connection-retry interval") + retryingHandshakeFailures = false + retryInterval = MIN_RETRY_INTERVAL + retryOffset = 0 + } + + /* + Set the next target to connect to + */ + private fun nextTarget() { + setTargetIndex() + + if (targetIndex == -1) { + if (handshakeFailureRetryTargets.isNotEmpty()) { + log.info("Failed to connect to any targets. Retrying targets that previously failed to handshake.") + handshakeFailureRetryTargets.clear() + retryingHandshakeFailures = true + setTargetIndex() + } else { + log.error("Attempted connection to targets: $targets, but none of them have presented acceptable certificates" + + " for $allowedRemoteLegalNames. Halting retries.") + return + } + } + setTargetRetryInterval() + log.info("Retry connect to ${targets[targetIndex]} in [$retryInterval] ms") } private val connectListener = object : ChannelFutureListener { @@ -212,7 +261,7 @@ class AMQPClient(val targets: List, onOpen = { _, change -> parent.run { amqpActive = true - retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly + successfullyConnected() _onConnection.onNext(change) } }, @@ -220,9 +269,9 @@ class AMQPClient(val targets: List, if (parent.amqpChannelHandler == amqpChannelHandler) { parent.run { _onConnection.onNext(change) - if (change.badCert) { - log.error("Blocking future connection attempts to $target due to bad certificate on endpoint") - badCertTargets += target + if (change.connectionResult == ConnectionResult.HANDSHAKE_FAILURE) { + log.warn("Handshake failure with $target target; will retry later") + handshakeFailureRetryTargets += target } if (started && amqpActive) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionChange.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionChange.kt index da839954ce..e900f93306 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionChange.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionChange.kt @@ -3,8 +3,8 @@ package net.corda.nodeapi.internal.protonwrapper.netty import java.net.InetSocketAddress import java.security.cert.X509Certificate -data class ConnectionChange(val remoteAddress: InetSocketAddress, val remoteCert: X509Certificate?, val connected: Boolean, val badCert: Boolean) { +data class ConnectionChange(val remoteAddress: InetSocketAddress, val remoteCert: X509Certificate?, val connected: Boolean, val connectionResult: ConnectionResult) { override fun toString(): String { - return "ConnectionChange remoteAddress: $remoteAddress connected state: $connected cert subject: ${remoteCert?.subjectDN} cert ok: ${!badCert}" + return "ConnectionChange remoteAddress: $remoteAddress connected state: $connected cert subject: ${remoteCert?.subjectDN} result: ${connectionResult}" } -} \ No newline at end of file +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionResult.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionResult.kt new file mode 100644 index 0000000000..fbd67bc138 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/ConnectionResult.kt @@ -0,0 +1,6 @@ +package net.corda.nodeapi.internal.protonwrapper.netty + +enum class ConnectionResult { + NO_ERROR, + HANDSHAKE_FAILURE +} diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt index b1bf4b99f8..dd85bae8e7 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt @@ -14,6 +14,7 @@ import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionResult import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig @@ -29,6 +30,7 @@ import org.junit.runner.RunWith import org.junit.runners.Parameterized import javax.net.ssl.KeyManagerFactory import javax.net.ssl.TrustManagerFactory +import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertTrue @@ -211,7 +213,7 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { val clientConnect = clientConnected.get() assertFalse(clientConnect.connected) // Not a badCert, but a timeout during handshake - assertFalse(clientConnect.badCert) + assertEquals(ConnectionResult.NO_ERROR, clientConnect.connectionResult) } } assertFalse(serverThread.isActive) diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 3f0ac98f4c..b067f24b40 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -36,6 +36,7 @@ import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.RoutingType +import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Assert.assertArrayEquals import org.junit.Rule @@ -207,6 +208,103 @@ class ProtonWrapperTests { assertTrue(done) } + @Suppress("TooGenericExceptionCaught") // Too generic exception thrown! + @Test(timeout=300_000) + fun `AMPQClient that fails to handshake with a server will retry the server`() { + /* + This test has been modelled on `Test AMQP Client with invalid root certificate`, above. + The aim is to set up a server with an invalid root cert so that the TLS handshake will fail. + The test allows the AMQPClient to retry the connection (which it should do). + */ + + val certificatesDirectory = temporaryFolder.root.toPath() + val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory, "serverstorepass") + val sslConfig = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, keyStorePassword = "serverstorepass") + + val (rootCa, intermediateCa) = createDevIntermediateCaCertPath() + + // Generate server cert and private key and populate another keystore suitable for SSL + signingCertificateStore.get(true).also { it.installDevNodeCaCertPath(ALICE_NAME, rootCa.certificate, intermediateCa) } + sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) } + sslConfig.createTrustStore(rootCa.certificate) + + val keyStore = sslConfig.keyStore.get() + val trustStore = sslConfig.trustStore.get() + + val context = SSLContext.getInstance("TLS") + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + keyManagerFactory.init(keyStore) + val keyManagers = keyManagerFactory.keyManagers + val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustMgrFactory.init(trustStore) + val trustManagers = trustMgrFactory.trustManagers + context.init(keyManagers, trustManagers, newSecureRandom()) + + val serverSocketFactory = context.serverSocketFactory + + val serverSocket = serverSocketFactory.createServerSocket(serverPort) as SSLServerSocket + val serverParams = SSLParameters(ArtemisTcpTransport.CIPHER_SUITES.toTypedArray(), + arrayOf("TLSv1.2")) + serverParams.wantClientAuth = true + serverParams.needClientAuth = true + serverParams.endpointIdentificationAlgorithm = null // Reconfirm default no server name indication, use our own validator. + serverSocket.sslParameters = serverParams + serverSocket.useClientMode = false + + var done = false + var handshakeErrorCount = 0 + + // + // This is the thread that acts as the server-side endpoint for the AMQPClient to connect to. + // + val serverThread = thread { + // + // The server thread will keep making itself available for SSL connections until + // the 'done' flag is set by the client thread, later on. + // + while (!done) { + try { + val sslServerSocket = serverSocket.accept() as SSLSocket + sslServerSocket.addHandshakeCompletedListener { + done = true + } + sslServerSocket.startHandshake() + } catch (ex: SSLException) { + ++handshakeErrorCount + } catch (e: Throwable) { + println(e) + } + } + } + + // + // Create the AMQPClient but only specify one server endpoint to connect to. + // + val amqpClient = createClient(serverAddressList = listOf(NetworkHostAndPort("localhost", serverPort))) + amqpClient.use { + + amqpClient.start() + // + // Waiting for the number of handshake errors to get to at least 2. + // This happens when the AMQPClient has made it's first retry attempt, which is + // what this test is interested in. + // + while (handshakeErrorCount < 2) { + Thread.sleep(2) + } + done = true + } + + serverThread.join(1000) + // + // check that there was at least one retry i.e. > 1 handshake error. + // + Assertions.assertThat(handshakeErrorCount > 1).isTrue() + + serverSocket.close() + assertTrue(done) + } + @Test(timeout=300_000) fun `Client Failover for multiple IP`() { @@ -450,7 +548,11 @@ class ProtonWrapperTests { return Pair(server, client) } - private fun createClient(maxMessageSize: Int = MAX_MESSAGE_SIZE): AMQPClient { + private fun createClient(maxMessageSize: Int = MAX_MESSAGE_SIZE, + serverAddressList: List = listOf( + NetworkHostAndPort("localhost", serverPort), + NetworkHostAndPort("localhost", serverPort2), + NetworkHostAndPort("localhost", artemisPort))): AMQPClient { val baseDirectory = temporaryFolder.root.toPath() / "client" val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) @@ -474,9 +576,7 @@ class ProtonWrapperTests { override val maxMessageSize: Int = maxMessageSize } return AMQPClient( - listOf(NetworkHostAndPort("localhost", serverPort), - NetworkHostAndPort("localhost", serverPort2), - NetworkHostAndPort("localhost", artemisPort)), + serverAddressList, setOf(ALICE_NAME, CHARLIE_NAME), amqpConfig) }