diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index e7452d49ab..d9602f0c43 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -326,6 +326,96 @@ class RPCStabilityTests { } } + interface ServerOps : RPCOps { + fun serverId(): String + } + + @Test + fun `client connects to first available server`() { + rpcDriver { + val ops = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server" + } + val serverFollower = shutdownManager.follower() + val serverAddress = startRpcServer(ops = ops).getOrThrow().broker.hostAndPort!! + serverFollower.unfollow() + + val clientFollower = shutdownManager.follower() + val client = startRpcClient(listOf(NetworkHostAndPort("localhost", 12345), serverAddress, NetworkHostAndPort("localhost", 54321))).getOrThrow() + clientFollower.unfollow() + + assertEquals("server", client.serverId()) + + clientFollower.shutdown() // Driver would do this after the new server, causing hang. + } + } + + @Test + fun `3 server failover`() { + rpcDriver { + val ops1 = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server1" + } + val ops2 = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server2" + } + val ops3 = object : ServerOps { + override val protocolVersion = 0 + override fun serverId() = "server3" + } + val serverFollower1 = shutdownManager.follower() + val server1 = startRpcServer(ops = ops1).getOrThrow() + serverFollower1.unfollow() + + val serverFollower2 = shutdownManager.follower() + val server2 = startRpcServer(ops = ops2).getOrThrow() + serverFollower2.unfollow() + + val serverFollower3 = shutdownManager.follower() + val server3 = startRpcServer(ops = ops3).getOrThrow() + serverFollower3.unfollow() + val servers = mutableMapOf("server1" to serverFollower1, "server2" to serverFollower2, "server3" to serverFollower3) + + val clientFollower = shutdownManager.follower() + val client = startRpcClient(listOf(server1.broker.hostAndPort!!, server2.broker.hostAndPort!!, server3.broker.hostAndPort!!)).getOrThrow() + clientFollower.unfollow() + + var response = client.serverId() + assertTrue(servers.containsKey(response)) + servers[response]!!.shutdown() + servers.remove(response) + + //failover will take some time + while (true) { + try { + response = client.serverId() + break + } catch (e: RPCException) {} + } + assertTrue(servers.containsKey(response)) + servers[response]!!.shutdown() + servers.remove(response) + + while (true) { + try { + response = client.serverId() + break + } catch (e: RPCException) {} + } + assertTrue(servers.containsKey(response)) + servers[response]!!.shutdown() + servers.remove(response) + + assertTrue(servers.isEmpty()) + + clientFollower.shutdown() // Driver would do this after the new server, causing hang. + + } + } + interface TrackSubscriberOps : RPCOps { fun subscribe(): Observable } diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index ed4ad14a23..a378808068 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -93,19 +93,34 @@ interface CordaRPCClientConfiguration { * [CordaRPCClientConfiguration]. While attempting failover, current and future RPC calls will throw * [RPCException] and previously returned observables will call onError(). * + * If the client was created using a list of hosts, automatic failover will occur(the servers have to be started in HA mode) + * * @param hostAndPort The network address to connect to. * @param configuration An optional configuration used to tweak client behaviour. * @param sslConfiguration An optional [SSLConfiguration] used to enable secure communication with the server. + * @param haAddressPool A list of [NetworkHostAndPort] representing the addresses of servers in HA mode. + * The client will attempt to connect to a live server by trying each address in the list. If the servers are not in + * HA mode, the client will round-robin from the beginning of the list and try all servers. */ class CordaRPCClient private constructor( - hostAndPort: NetworkHostAndPort, - configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), - sslConfiguration: SSLConfiguration? = null, - classLoader: ClassLoader? = null + private val hostAndPort: NetworkHostAndPort, + private val configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), + private val sslConfiguration: SSLConfiguration? = null, + private val classLoader: ClassLoader? = null, + private val haAddressPool: List = emptyList() ) { @JvmOverloads constructor(hostAndPort: NetworkHostAndPort, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default()) : this(hostAndPort, configuration, null) + /** + * @param haAddressPool A list of [NetworkHostAndPort] representing the addresses of servers in HA mode. + * The client will attempt to connect to a live server by trying each address in the list. If the servers are not in + * HA mode, the client will round-robin from the beginning of the list and try all servers. + * @param configuration An optional configuration used to tweak client behaviour. + */ + @JvmOverloads + constructor(haAddressPool: List, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default()) : this(haAddressPool.first(), configuration, null, null, haAddressPool) + companion object { internal fun createWithSsl( hostAndPort: NetworkHostAndPort, @@ -115,6 +130,14 @@ class CordaRPCClient private constructor( return CordaRPCClient(hostAndPort, configuration, sslConfiguration) } + internal fun createWithSsl( + haAddressPool: List, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), + sslConfiguration: SSLConfiguration? = null + ): CordaRPCClient { + return CordaRPCClient(haAddressPool.first(), configuration, sslConfiguration, null, haAddressPool) + } + internal fun createWithSslAndClassLoader( hostAndPort: NetworkHostAndPort, configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), @@ -123,6 +146,15 @@ class CordaRPCClient private constructor( ): CordaRPCClient { return CordaRPCClient(hostAndPort, configuration, sslConfiguration, classLoader) } + + internal fun createWithSslAndClassLoader( + haAddressPool: List, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default(), + sslConfiguration: SSLConfiguration? = null, + classLoader: ClassLoader? = null + ): CordaRPCClient { + return CordaRPCClient(haAddressPool.first(), configuration, sslConfiguration, classLoader, haAddressPool) + } } init { @@ -137,11 +169,19 @@ class CordaRPCClient private constructor( } } - private val rpcClient = RPCClient( - tcpTransport(ConnectionDirection.Outbound(), hostAndPort, config = sslConfiguration), - configuration, - if (classLoader != null) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classLoader) else KRYO_RPC_CLIENT_CONTEXT - ) + private fun getRpcClient() : RPCClient { + return if (haAddressPool.isEmpty()) { + RPCClient( + tcpTransport(ConnectionDirection.Outbound(), hostAndPort, config = sslConfiguration), + configuration, + if (classLoader != null) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classLoader) else KRYO_RPC_CLIENT_CONTEXT) + } else { + RPCClient(haAddressPool, + sslConfiguration, + configuration, + if (classLoader != null) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classLoader) else KRYO_RPC_CLIENT_CONTEXT) + } + } /** * Logs in to the target server and returns an active connection. The returned connection is a [java.io.Closeable] @@ -169,7 +209,7 @@ class CordaRPCClient private constructor( * @throws RPCException if the server version is too low or if the server isn't reachable within a reasonable timeout. */ fun start(username: String, password: String, externalTrace: Trace?, impersonatedActor: Actor?): CordaRPCConnection { - return CordaRPCConnection(rpcClient.start(CordaRPCOps::class.java, username, password, externalTrace, impersonatedActor)) + return CordaRPCConnection(getRpcClient().start(CordaRPCOps::class.java, username, password, externalTrace, impersonatedActor)) } /** diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index e8f33d284f..24096c0952 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -14,6 +14,7 @@ import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.* import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport +import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransportsFromList import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.RPCApi import net.corda.nodeapi.internal.config.SSLConfiguration @@ -60,7 +61,8 @@ data class CordaRPCClientConfigurationImpl( class RPCClient( val transport: TransportConfiguration, val rpcConfiguration: CordaRPCClientConfiguration = CordaRPCClientConfigurationImpl.default, - val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT + val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT, + val haPoolTransportConfigurations: List = emptyList() ) { constructor( hostAndPort: NetworkHostAndPort, @@ -69,6 +71,14 @@ class RPCClient( serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration, serializationContext) + constructor( + haAddressPool: List, + sslConfiguration: SSLConfiguration? = null, + configuration: CordaRPCClientConfiguration = CordaRPCClientConfigurationImpl.default, + serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT + ) : this(tcpTransport(ConnectionDirection.Outbound(), haAddressPool.first(), sslConfiguration), + configuration, serializationContext, tcpTransportsFromList(ConnectionDirection.Outbound(), haAddressPool, sslConfiguration)) + companion object { private val log = contextLogger() } @@ -83,11 +93,15 @@ class RPCClient( return log.logElapsedTime("Startup") { val clientAddress = SimpleString("${RPCApi.RPC_CLIENT_QUEUE_NAME_PREFIX}.$username.${random63BitValue()}") - val serverLocator = ActiveMQClient.createServerLocatorWithoutHA(transport).apply { + val serverLocator = (if (haPoolTransportConfigurations.isEmpty()) { + ActiveMQClient.createServerLocatorWithoutHA(transport) + } else { + ActiveMQClient.createServerLocatorWithoutHA(*haPoolTransportConfigurations.toTypedArray()) + }).apply { retryInterval = rpcConfiguration.connectionRetryInterval.toMillis() retryIntervalMultiplier = rpcConfiguration.connectionRetryIntervalMultiplier maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval.toMillis() - reconnectAttempts = rpcConfiguration.maxReconnectAttempts + reconnectAttempts = if (haPoolTransportConfigurations.isEmpty()) rpcConfiguration.maxReconnectAttempts else 0 minLargeMessageSize = rpcConfiguration.maxFileSize isUseGlobalPools = nodeSerializationEnv != null } diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index dc07b1ada2..8852a096cd 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -20,6 +20,7 @@ import net.corda.core.context.Trace.InvocationId import net.corda.core.internal.LazyStickyPool import net.corda.core.internal.LifeCycle import net.corda.core.internal.ThreadBox +import net.corda.core.internal.times import net.corda.core.messaging.RPCOps import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.serialize @@ -29,6 +30,7 @@ import net.corda.core.utilities.debug import net.corda.core.utilities.getOrThrow import net.corda.nodeapi.RPCApi import net.corda.nodeapi.internal.DeduplicationChecker +import org.apache.activemq.artemis.api.core.ActiveMQException import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.* @@ -173,6 +175,8 @@ class RPCClientProxyHandler( private val deduplicationSequenceNumber = AtomicLong(0) private val sendingEnabled = AtomicBoolean(true) + // used to interrupt failover thread (i.e. client is closed while failing over) + private var haFailoverThread: Thread? = null /** * Start the client. This creates the per-client queue, starts the consumer session and the reaper. @@ -192,17 +196,22 @@ class RPCClientProxyHandler( rpcConfiguration.reapInterval.toMillis(), TimeUnit.MILLISECONDS ) + // Create a session factory using the first available server. If more than one transport configuration was + // used when creating the server locator, every one will be tried during failover. The locator will round-robin + // through the available transport configurations with the starting position being generated randomly. + // If there's only one available, that one will be retried continuously as configured in rpcConfiguration. + // There is no failover on first attempt, meaning that if a connection cannot be established, the serverLocator + // will try another transport if it exists or throw an exception otherwise. sessionFactory = serverLocator.createSessionFactory() - producerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - rpcProducer = producerSession!!.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME) - consumerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) - consumerSession!!.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress) - rpcConsumer = consumerSession!!.createConsumer(clientAddress) - rpcConsumer!!.setMessageHandler(this::artemisMessageHandler) - producerSession!!.addFailoverListener(this::failoverHandler) + // Depending on how the client is constructed, connection failure is treated differently + if (serverLocator.staticTransportConfigurations.size == 1) { + sessionFactory!!.addFailoverListener(this::failoverHandler) + } else { + sessionFactory!!.addFailoverListener(this::haFailoverHandler) + } + initSessions() lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET) - consumerSession!!.start() - producerSession!!.start() + startSessions() } // This is the general function that transforms a client side RPC to internal Artemis messages. @@ -341,6 +350,10 @@ class RPCClientProxyHandler( * @param notify whether to notify observables or not. */ private fun close(notify: Boolean = true) { + haFailoverThread?.apply { + interrupt() + join(1000) + } sessionFactory?.close() reaperScheduledFuture?.cancel(false) observableContext.observableMap.invalidateAll() @@ -403,26 +416,82 @@ class RPCClientProxyHandler( } } + private fun attemptReconnect() { + var reconnectAttempts = rpcConfiguration.maxReconnectAttempts * serverLocator.staticTransportConfigurations.size + var retryInterval = rpcConfiguration.connectionRetryInterval + val maxRetryInterval = rpcConfiguration.connectionMaxRetryInterval + + var transportIterator = serverLocator.staticTransportConfigurations.iterator() + while (transportIterator.hasNext() && reconnectAttempts != 0) { + val transport = transportIterator.next() + if (!transportIterator.hasNext()) + transportIterator = serverLocator.staticTransportConfigurations.iterator() + + log.debug("Trying to connect using ${transport.params}") + try { + if (serverLocator != null && !serverLocator.isClosed) { + sessionFactory = serverLocator.createSessionFactory(transport) + } else { + log.warn("Stopping reconnect attempts.") + log.debug("Server locator is closed or garbage collected. Proxy may have been closed during reconnect.") + break + } + } catch (e: ActiveMQException) { + try { + Thread.sleep(retryInterval.toMillis()) + } catch (e: InterruptedException) {} + // could not connect, try with next server transport + reconnectAttempts-- + retryInterval = minOf(maxRetryInterval, retryInterval.times(rpcConfiguration.connectionRetryIntervalMultiplier.toLong())) + continue + } + + log.debug("Connected successfully using ${transport.params}") + log.info("RPC server available.") + sessionFactory!!.addFailoverListener(this::haFailoverHandler) + initSessions() + startSessions() + sendingEnabled.set(true) + break + } + + if (reconnectAttempts == 0 || sessionFactory == null) + log.error("Could not reconnect to the RPC server.") + } + + private fun initSessions() { + producerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + rpcProducer = producerSession!!.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME) + consumerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE) + consumerSession!!.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress) + rpcConsumer = consumerSession!!.createConsumer(clientAddress) + rpcConsumer!!.setMessageHandler(this::artemisMessageHandler) + } + + private fun startSessions() { + consumerSession!!.start() + producerSession!!.start() + } + + private fun haFailoverHandler(event: FailoverEventType) { + if (event == FailoverEventType.FAILURE_DETECTED) { + log.warn("Connection failure. Attempting to reconnect using back-up addresses.") + cleanUpOnConnectionLoss() + sessionFactory?.apply { + connection.destroy() + cleanup() + close() + } + haFailoverThread = Thread.currentThread() + attemptReconnect() + } + /* Other events are not considered as reconnection is not done by Artemis */ + } + private fun failoverHandler(event: FailoverEventType) { when (event) { FailoverEventType.FAILURE_DETECTED -> { - sendingEnabled.set(false) - - log.warn("Terminating observables.") - val m = observableContext.observableMap.asMap() - m.keys.forEach { k -> - observationExecutorPool.run(k) { - m[k]?.onError(RPCException("Connection failure detected.")) - } - } - observableContext.observableMap.invalidateAll() - - rpcReplyMap.forEach { _, replyFuture -> - replyFuture.setException(RPCException("Connection failure detected.")) - } - - rpcReplyMap.clear() - callSiteMap?.clear() + cleanUpOnConnectionLoss() } FailoverEventType.FAILOVER_COMPLETED -> { @@ -435,6 +504,25 @@ class RPCClientProxyHandler( } } } + + private fun cleanUpOnConnectionLoss() { + sendingEnabled.set(false) + log.warn("Terminating observables.") + val m = observableContext.observableMap.asMap() + m.keys.forEach { k -> + observationExecutorPool.run(k) { + m[k]?.onError(RPCException("Connection failure detected.")) + } + } + observableContext.observableMap.invalidateAll() + + rpcReplyMap.forEach { _, replyFuture -> + replyFuture.setException(RPCException("Connection failure detected.")) + } + + rpcReplyMap.clear() + callSiteMap?.clear() + } } private typealias RpcObservableMap = Cache>> diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt index 042a38ce67..5e04f66414 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/RPCDriver.kt @@ -75,6 +75,13 @@ inline fun RPCDriverDSL.startRpcClient( configuration: CordaRPCClientConfigurationImpl = CordaRPCClientConfigurationImpl.default ) = startRpcClient(I::class.java, rpcAddress, username, password, configuration) +inline fun RPCDriverDSL.startRpcClient( + haAddressPool: List, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: CordaRPCClientConfigurationImpl = CordaRPCClientConfigurationImpl.default +) = startRpcClient(I::class.java, haAddressPool, username, password, configuration) + data class RpcBrokerHandle( val hostAndPort: NetworkHostAndPort?, /** null if this is an InVM broker */ @@ -336,6 +343,32 @@ data class RPCDriverDSL( } } + /** + * Starts a Netty RPC client. + * + * @param rpcOpsClass The [Class] of the RPC interface. + * @param haAddressPool The addresses of the RPC servers(configured in HA mode) to connect to. + * @param username The username to authenticate with. + * @param password The password to authenticate with. + * @param configuration The RPC client configuration. + */ + fun startRpcClient( + rpcOpsClass: Class, + haAddressPool: List, + username: String = rpcTestUser.username, + password: String = rpcTestUser.password, + configuration: CordaRPCClientConfigurationImpl = CordaRPCClientConfigurationImpl.default + ): CordaFuture { + return driverDSL.executorService.fork { + val client = RPCClient(haAddressPool, null, configuration) + val connection = client.start(rpcOpsClass, username, password, externalTrace) + driverDSL.shutdownManager.registerShutdown { + connection.close() + } + connection.proxy + } + } + /** * Starts a Netty RPC client in a new JVM process that calls random RPCs with random arguments. *