diff --git a/bridge/src/main/kotlin/net/corda/bridge/services/api/BridgeConfiguration.kt b/bridge/src/main/kotlin/net/corda/bridge/services/api/BridgeConfiguration.kt index 7be9e2758f..5238db8528 100644 --- a/bridge/src/main/kotlin/net/corda/bridge/services/api/BridgeConfiguration.kt +++ b/bridge/src/main/kotlin/net/corda/bridge/services/api/BridgeConfiguration.kt @@ -4,6 +4,7 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.internal.config.NodeSSLConfiguration import net.corda.nodeapi.internal.config.SSLConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.SocksProxyConfig import java.nio.file.Path enum class BridgeMode { @@ -35,6 +36,8 @@ interface BridgeOutboundConfiguration { val artemisBrokerAddress: NetworkHostAndPort // Allows override of [KeyStore] details for the artemis connection, otherwise the general top level details are used. val customSSLConfiguration: SSLConfiguration? + // Allows use of a SOCKS 4/5 proxy + val socksProxyConfig: SocksProxyConfig? } /** diff --git a/bridge/src/main/kotlin/net/corda/bridge/services/config/BridgeConfigurationImpl.kt b/bridge/src/main/kotlin/net/corda/bridge/services/config/BridgeConfigurationImpl.kt index 4893263907..61ff070d35 100644 --- a/bridge/src/main/kotlin/net/corda/bridge/services/config/BridgeConfigurationImpl.kt +++ b/bridge/src/main/kotlin/net/corda/bridge/services/config/BridgeConfigurationImpl.kt @@ -7,6 +7,7 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.internal.ArtemisMessagingComponent import net.corda.nodeapi.internal.config.SSLConfiguration import net.corda.nodeapi.internal.config.parseAs +import net.corda.nodeapi.internal.protonwrapper.netty.SocksProxyConfig import java.nio.file.Path @@ -17,7 +18,8 @@ data class CustomSSLConfiguration(override val keyStorePassword: String, override val certificatesDirectory: Path) : SSLConfiguration data class BridgeOutboundConfigurationImpl(override val artemisBrokerAddress: NetworkHostAndPort, - override val customSSLConfiguration: CustomSSLConfiguration?) : BridgeOutboundConfiguration + override val customSSLConfiguration: CustomSSLConfiguration?, + override val socksProxyConfig: SocksProxyConfig? = null) : BridgeOutboundConfiguration data class BridgeInboundConfigurationImpl(override val listeningAddress: NetworkHostAndPort, override val customSSLConfiguration: CustomSSLConfiguration?) : BridgeInboundConfiguration diff --git a/bridge/src/main/kotlin/net/corda/bridge/services/sender/DirectBridgeSenderService.kt b/bridge/src/main/kotlin/net/corda/bridge/services/sender/DirectBridgeSenderService.kt index 3409dfa80d..7efd9fa481 100644 --- a/bridge/src/main/kotlin/net/corda/bridge/services/sender/DirectBridgeSenderService.kt +++ b/bridge/src/main/kotlin/net/corda/bridge/services/sender/DirectBridgeSenderService.kt @@ -22,7 +22,7 @@ class DirectBridgeSenderService(val conf: BridgeConfiguration, private val statusFollower: ServiceStateCombiner private var statusSubscriber: Subscription? = null private var connectionSubscriber: Subscription? = null - private var bridgeControlListener: BridgeControlListener = BridgeControlListener(conf, { ForwardingArtemisMessageClient(artemisConnectionService) }) + private var bridgeControlListener: BridgeControlListener = BridgeControlListener(conf, conf.outboundConfig!!.socksProxyConfig, { ForwardingArtemisMessageClient(artemisConnectionService) }) init { statusFollower = ServiceStateCombiner(listOf(auditService, artemisConnectionService, haService)) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt index 5dc5840ceb..31a9673a46 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt @@ -28,6 +28,7 @@ import net.corda.nodeapi.internal.bridging.AMQPBridgeManager.AMQPBridge.Companio import net.corda.nodeapi.internal.config.NodeSSLConfiguration import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient +import net.corda.nodeapi.internal.protonwrapper.netty.SocksProxyConfig import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.client.ClientConsumer @@ -47,7 +48,7 @@ import kotlin.concurrent.withLock * The Netty thread pool used by the AMQPBridges is also shared and managed by the AMQPBridgeManager. */ @VisibleForTesting -class AMQPBridgeManager(config: NodeSSLConfiguration, val artemisMessageClientFactory: () -> ArtemisSessionProvider) : BridgeManager { +class AMQPBridgeManager(config: NodeSSLConfiguration, private val socksProxyConfig: SocksProxyConfig? = null, val artemisMessageClientFactory: () -> ArtemisSessionProvider) : BridgeManager { private val lock = ReentrantLock() private val bridgeNameToBridgeMap = mutableMapOf() @@ -57,7 +58,7 @@ class AMQPBridgeManager(config: NodeSSLConfiguration, val artemisMessageClientFa private val trustStore = config.loadTrustStore().internal private var artemis: ArtemisSessionProvider? = null - constructor(config: NodeSSLConfiguration, p2pAddress: NetworkHostAndPort, maxMessageSize: Int) : this(config, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) }) + constructor(config: NodeSSLConfiguration, p2pAddress: NetworkHostAndPort, maxMessageSize: Int, socksProxyConfig: SocksProxyConfig? = null) : this(config, socksProxyConfig, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) }) companion object { private const val NUM_BRIDGE_THREADS = 0 // Default sized pool @@ -78,6 +79,7 @@ class AMQPBridgeManager(config: NodeSSLConfiguration, val artemisMessageClientFa keyStorePrivateKeyPassword: String, trustStore: KeyStore, sharedEventGroup: EventLoopGroup, + socksProxyConfig: SocksProxyConfig?, private val artemis: ArtemisSessionProvider) { companion object { fun getBridgeName(queueName: String, hostAndPort: NetworkHostAndPort): String = "$queueName -> $hostAndPort" @@ -85,7 +87,7 @@ class AMQPBridgeManager(config: NodeSSLConfiguration, val artemisMessageClientFa private val log = LoggerFactory.getLogger("$bridgeName:${legalNames.first()}") - val amqpClient = AMQPClient(listOf(target), legalNames, PEER_USER, PEER_USER, keyStore, keyStorePrivateKeyPassword, trustStore, sharedThreadPool = sharedEventGroup) + val amqpClient = AMQPClient(listOf(target), legalNames, PEER_USER, PEER_USER, keyStore, keyStorePrivateKeyPassword, trustStore, sharedThreadPool = sharedEventGroup, socksProxyConfig = socksProxyConfig) val bridgeName: String get() = getBridgeName(queueName, target) private val lock = ReentrantLock() // lock to serialise session level access private var session: ClientSession? = null @@ -179,7 +181,7 @@ class AMQPBridgeManager(config: NodeSSLConfiguration, val artemisMessageClientFa if (bridgeExists(getBridgeName(queueName, target))) { return } - val newBridge = AMQPBridge(queueName, target, legalNames, keyStore, keyStorePrivateKeyPassword, trustStore, sharedEventLoopGroup!!, artemis!!) + val newBridge = AMQPBridge(queueName, target, legalNames, keyStore, keyStorePrivateKeyPassword, trustStore, sharedEventLoopGroup!!, socksProxyConfig, artemis!!) lock.withLock { bridgeNameToBridgeMap[newBridge.bridgeName] = newBridge } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt index 7145800cbd..680e56201d 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt @@ -22,6 +22,7 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX import net.corda.nodeapi.internal.ArtemisSessionProvider import net.corda.nodeapi.internal.config.NodeSSLConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.SocksProxyConfig import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ClientConsumer @@ -29,16 +30,18 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage import java.util.* class BridgeControlListener(val config: NodeSSLConfiguration, + socksProxyConfig: SocksProxyConfig? = null, val artemisMessageClientFactory: () -> ArtemisSessionProvider) : AutoCloseable { private val bridgeId: String = UUID.randomUUID().toString() - private val bridgeManager: BridgeManager = AMQPBridgeManager(config, artemisMessageClientFactory) + private val bridgeManager: BridgeManager = AMQPBridgeManager(config, socksProxyConfig, artemisMessageClientFactory) private val validInboundQueues = mutableSetOf() private var artemis: ArtemisSessionProvider? = null private var controlConsumer: ClientConsumer? = null constructor(config: NodeSSLConfiguration, p2pAddress: NetworkHostAndPort, - maxMessageSize: Int) : this(config, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) }) + maxMessageSize: Int, + socksProxy: SocksProxyConfig? = null) : this(config, socksProxy, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) }) companion object { private val log = contextLogger() diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt index 38085d92fb..2c3f3a2b40 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPChannelHandler.kt @@ -15,6 +15,8 @@ import io.netty.channel.ChannelDuplexHandler import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelPromise import io.netty.channel.socket.SocketChannel +import io.netty.handler.proxy.ProxyConnectException +import io.netty.handler.proxy.ProxyConnectionEvent import io.netty.handler.ssl.SslHandler import io.netty.handler.ssl.SslHandshakeCompletionEvent import io.netty.util.ReferenceCountUtil @@ -51,6 +53,7 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, private var localCert: X509Certificate? = null private var remoteCert: X509Certificate? = null private var eventProcessor: EventProcessor? = null + private var suppressClose: Boolean = false override fun channelActive(ctx: ChannelHandlerContext) { val ch = ctx.channel() @@ -82,12 +85,17 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, override fun channelInactive(ctx: ChannelHandlerContext) { val ch = ctx.channel() log.info("Closed client connection ${ch.id()} from $remoteAddress to ${ch.localAddress()}") - onClose(Pair(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false))) + if (!suppressClose) { + onClose(Pair(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false))) + } eventProcessor?.close() ctx.fireChannelInactive() } override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if (evt is ProxyConnectionEvent) { + remoteAddress = evt.destinationAddress() // update address to teh real target address + } if (evt is SslHandshakeCompletionEvent) { if (evt.isSuccess) { val sslHandler = ctx.pipeline().get(SslHandler::class.java) @@ -111,6 +119,15 @@ internal class AMQPChannelHandler(private val serverMode: Boolean, } } + + @Suppress("OverridingDeprecatedMember") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { + if (cause is ProxyConnectException) { + log.warn("Proxy connection failed ${cause.message}") + suppressClose = true // The pipeline gets marked as active on connection to the proxy rather than to the target, which causes excess close events + } + } + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { try { log.debug { "Received $msg" } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 0148d8cedb..f927ce7b2d 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -17,6 +17,8 @@ import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler +import io.netty.handler.proxy.Socks4ProxyHandler +import io.netty.handler.proxy.Socks5ProxyHandler import io.netty.util.internal.logging.InternalLoggerFactory import io.netty.util.internal.logging.Slf4JLoggerFactory import net.corda.core.identity.CordaX500Name @@ -27,6 +29,7 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import rx.Observable import rx.subjects.PublishSubject +import java.net.InetSocketAddress import java.security.KeyStore import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock @@ -34,6 +37,19 @@ import javax.net.ssl.KeyManagerFactory import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock +enum class SocksProxyVersion { + SOCKS4, + SOCKS5 +} + +data class SocksProxyConfig(val version: SocksProxyVersion, val proxyAddress: NetworkHostAndPort, val userName: String? = null, val password: String? = null) { + init { + if (version == SocksProxyVersion.SOCKS4) { + require(password == null) { "SOCKS4 does not support a password" } + } + } +} + /** * The AMQPClient creates a connection initiator that will try to connect in a round-robin fashion * to the first open SSL socket. It will keep retrying until it is stopped. @@ -49,7 +65,8 @@ class AMQPClient(val targets: List, private val keyStorePrivateKeyPassword: String, private val trustStore: KeyStore, private val trace: Boolean = false, - private val sharedThreadPool: EventLoopGroup? = null) : AutoCloseable { + private val sharedThreadPool: EventLoopGroup? = null, + private val socksProxyConfig: SocksProxyConfig? = null) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) @@ -117,6 +134,25 @@ class AMQPClient(val targets: List, override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() + val socksConfig = parent.socksProxyConfig + if (socksConfig != null) { + val proxyAddress = InetSocketAddress(socksConfig.proxyAddress.host, socksConfig.proxyAddress.port) + val proxy = when (parent.socksProxyConfig!!.version) { + SocksProxyVersion.SOCKS4 -> { + Socks4ProxyHandler(proxyAddress, socksConfig.userName) + } + SocksProxyVersion.SOCKS5 -> { + Socks5ProxyHandler(proxyAddress, socksConfig.userName, socksConfig.password) + } + } + pipeline.addLast("SocksPoxy", proxy) + proxy.connectFuture().addListener { + if (!it.isSuccess) { + ch.disconnect() + } + } + } + val handler = createClientSslHelper(parent.currentTarget, keyManagerFactory, trustManagerFactory) pipeline.addLast("sslHandler", handler) if (parent.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO)) diff --git a/node/build.gradle b/node/build.gradle index 011bfb8d9f..1a6450d512 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -229,6 +229,11 @@ dependencies { // Jolokia JVM monitoring agent, required to push logs through slf4j compile "org.jolokia:jolokia-jvm:${jolokia_version}:agent" + + // Allow access to simple SOCKS Server for integration testing + testCompile('io.netty:netty-example:4.1.9.Final') { + exclude group: "io.netty", module: "netty-tcnative" + } } task integrationTest(type: Test) { diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/SocksTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/SocksTests.kt new file mode 100644 index 0000000000..73603bced2 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/SocksTests.kt @@ -0,0 +1,362 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. + */ + +package net.corda.node.amqp + +import com.nhaarman.mockito_kotlin.doReturn +import com.nhaarman.mockito_kotlin.whenever +import io.netty.bootstrap.ServerBootstrap +import io.netty.channel.ChannelFuture +import io.netty.channel.EventLoopGroup +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.nio.NioServerSocketChannel +import io.netty.example.socksproxy.SocksServerInitializer +import io.netty.handler.logging.LogLevel +import io.netty.handler.logging.LoggingHandler +import net.corda.core.identity.CordaX500Name +import net.corda.core.internal.div +import net.corda.core.toFuture +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.node.services.config.* +import net.corda.node.services.messaging.ArtemisMessagingServer +import net.corda.nodeapi.internal.ArtemisMessagingClient +import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX +import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER +import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus +import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient +import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer +import net.corda.nodeapi.internal.protonwrapper.netty.SocksProxyConfig +import net.corda.nodeapi.internal.protonwrapper.netty.SocksProxyVersion +import net.corda.testing.core.* +import net.corda.testing.internal.rigorousMock +import org.apache.activemq.artemis.api.core.RoutingType +import org.junit.After +import org.junit.Assert.assertArrayEquals +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import kotlin.test.assertEquals + +class SocksTests { + @Rule + @JvmField + val temporaryFolder = TemporaryFolder() + + private val socksPort = freePort() + private val serverPort = freePort() + private val serverPort2 = freePort() + private val artemisPort = freePort() + + private abstract class AbstractNodeConfiguration : NodeConfiguration + + private class SocksServer(val port: Int) { + private val bossGroup = NioEventLoopGroup(1) + private val workerGroup = NioEventLoopGroup() + private var closeFuture: ChannelFuture? = null + + init { + try { + val b = ServerBootstrap() + b.group(bossGroup, workerGroup) + .channel(NioServerSocketChannel::class.java) + .handler(LoggingHandler(LogLevel.INFO)) + .childHandler(SocksServerInitializer()) + closeFuture = b.bind(port).sync().channel().closeFuture() + } catch (ex: Exception) { + bossGroup.shutdownGracefully() + workerGroup.shutdownGracefully() + } + } + + fun close() { + bossGroup.shutdownGracefully() + workerGroup.shutdownGracefully() + closeFuture?.sync() + } + } + + private var socksProxy: SocksServer? = null + + @Before + fun setup() { + socksProxy = SocksServer(socksPort) + } + + @After + fun shutdown() { + socksProxy?.close() + socksProxy = null + } + + @Test + fun `Simple AMPQ Client to Server`() { + val amqpServer = createServer(serverPort) + amqpServer.use { + amqpServer.start() + val receiveSubs = amqpServer.onReceive.subscribe { + assertEquals(BOB_NAME.toString(), it.sourceLegalName) + assertEquals(P2P_PREFIX + "Test", it.topic) + assertEquals("Test", String(it.payload)) + it.complete(true) + } + val amqpClient = createClient() + amqpClient.use { + val serverConnected = amqpServer.onConnection.toFuture() + val clientConnected = amqpClient.onConnection.toFuture() + amqpClient.start() + val serverConnect = serverConnected.get() + assertEquals(true, serverConnect.connected) + assertEquals(BOB_NAME, CordaX500Name.build(serverConnect.remoteCert!!.subjectX500Principal)) + val clientConnect = clientConnected.get() + assertEquals(true, clientConnect.connected) + assertEquals(ALICE_NAME, CordaX500Name.build(clientConnect.remoteCert!!.subjectX500Principal)) + val msg = amqpClient.createMessage("Test".toByteArray(), + P2P_PREFIX + "Test", + ALICE_NAME.toString(), + emptyMap()) + amqpClient.write(msg) + assertEquals(MessageStatus.Acknowledged, msg.onComplete.get()) + receiveSubs.unsubscribe() + } + } + } + + @Test + fun `AMPQ Client refuses to connect to unexpected server`() { + val amqpServer = createServer(serverPort, CordaX500Name("Rogue 1", "London", "GB")) + amqpServer.use { + amqpServer.start() + val amqpClient = createClient() + amqpClient.use { + val clientConnected = amqpClient.onConnection.toFuture() + amqpClient.start() + val clientConnect = clientConnected.get() + assertEquals(false, clientConnect.connected) + } + } + } + + @Test + fun `Client Failover for multiple IP`() { + val amqpServer = createServer(serverPort) + val amqpServer2 = createServer(serverPort2) + val amqpClient = createClient() + try { + val serverConnected = amqpServer.onConnection.toFuture() + val serverConnected2 = amqpServer2.onConnection.toFuture() + val clientConnected = amqpClient.onConnection.toBlocking().iterator + amqpServer.start() + amqpClient.start() + val serverConn1 = serverConnected.get() + assertEquals(true, serverConn1.connected) + assertEquals(BOB_NAME, CordaX500Name.build(serverConn1.remoteCert!!.subjectX500Principal)) + val connState1 = clientConnected.next() + assertEquals(true, connState1.connected) + assertEquals(ALICE_NAME, CordaX500Name.build(connState1.remoteCert!!.subjectX500Principal)) + assertEquals(serverPort, connState1.remoteAddress.port) + + // Fail over + amqpServer2.start() + amqpServer.stop() + val connState2 = clientConnected.next() + assertEquals(false, connState2.connected) + assertEquals(serverPort, connState2.remoteAddress.port) + val serverConn2 = serverConnected2.get() + assertEquals(true, serverConn2.connected) + assertEquals(BOB_NAME, CordaX500Name.build(serverConn2.remoteCert!!.subjectX500Principal)) + val connState3 = clientConnected.next() + assertEquals(true, connState3.connected) + assertEquals(ALICE_NAME, CordaX500Name.build(connState3.remoteCert!!.subjectX500Principal)) + assertEquals(serverPort2, connState3.remoteAddress.port) + + // Fail back + amqpServer.start() + amqpServer2.stop() + val connState4 = clientConnected.next() + assertEquals(false, connState4.connected) + assertEquals(serverPort2, connState4.remoteAddress.port) + val serverConn3 = serverConnected.get() + assertEquals(true, serverConn3.connected) + assertEquals(BOB_NAME, CordaX500Name.build(serverConn3.remoteCert!!.subjectX500Principal)) + val connState5 = clientConnected.next() + assertEquals(true, connState5.connected) + assertEquals(ALICE_NAME, CordaX500Name.build(connState5.remoteCert!!.subjectX500Principal)) + assertEquals(serverPort, connState5.remoteAddress.port) + } finally { + amqpClient.close() + amqpServer.close() + amqpServer2.close() + } + } + + @Test + fun `Send a message from AMQP to Artemis inbox`() { + val (server, artemisClient) = createArtemisServerAndClient() + val amqpClient = createClient() + val clientConnected = amqpClient.onConnection.toFuture() + amqpClient.start() + assertEquals(true, clientConnected.get().connected) + assertEquals(CHARLIE_NAME, CordaX500Name.build(clientConnected.get().remoteCert!!.subjectX500Principal)) + val artemis = artemisClient.started!! + val sendAddress = P2P_PREFIX + "Test" + artemis.session.createQueue(sendAddress, RoutingType.ANYCAST, "queue", true) + val consumer = artemis.session.createConsumer("queue") + val testData = "Test".toByteArray() + val testProperty = mutableMapOf() + testProperty["TestProp"] = "1" + val message = amqpClient.createMessage(testData, sendAddress, CHARLIE_NAME.toString(), testProperty) + amqpClient.write(message) + assertEquals(MessageStatus.Acknowledged, message.onComplete.get()) + val received = consumer.receive() + assertEquals("1", received.getStringProperty("TestProp")) + assertArrayEquals(testData, ByteArray(received.bodySize).apply { received.bodyBuffer.readBytes(this) }) + amqpClient.stop() + artemisClient.stop() + server.stop() + } + + @Test + fun `shared AMQPClient threadpool tests`() { + val amqpServer = createServer(serverPort) + amqpServer.use { + val connectionEvents = amqpServer.onConnection.toBlocking().iterator + amqpServer.start() + val sharedThreads = NioEventLoopGroup() + val amqpClient1 = createSharedThreadsClient(sharedThreads, 0) + val amqpClient2 = createSharedThreadsClient(sharedThreads, 1) + amqpClient1.start() + val connection1 = connectionEvents.next() + assertEquals(true, connection1.connected) + val connection1ID = CordaX500Name.build(connection1.remoteCert!!.subjectX500Principal) + assertEquals("client 0", connection1ID.organisationUnit) + val source1 = connection1.remoteAddress + amqpClient2.start() + val connection2 = connectionEvents.next() + assertEquals(true, connection2.connected) + val connection2ID = CordaX500Name.build(connection2.remoteCert!!.subjectX500Principal) + assertEquals("client 1", connection2ID.organisationUnit) + val source2 = connection2.remoteAddress + // Stopping one shouldn't disconnect the other + amqpClient1.stop() + val connection3 = connectionEvents.next() + assertEquals(false, connection3.connected) + assertEquals(source1, connection3.remoteAddress) + assertEquals(false, amqpClient1.connected) + assertEquals(true, amqpClient2.connected) + // Now shutdown both + amqpClient2.stop() + val connection4 = connectionEvents.next() + assertEquals(false, connection4.connected) + assertEquals(source2, connection4.remoteAddress) + assertEquals(false, amqpClient1.connected) + assertEquals(false, amqpClient2.connected) + // Now restarting one should work + amqpClient1.start() + val connection5 = connectionEvents.next() + assertEquals(true, connection5.connected) + val connection5ID = CordaX500Name.build(connection5.remoteCert!!.subjectX500Principal) + assertEquals("client 0", connection5ID.organisationUnit) + assertEquals(true, amqpClient1.connected) + assertEquals(false, amqpClient2.connected) + // Cleanup + amqpClient1.stop() + sharedThreads.shutdownGracefully() + sharedThreads.terminationFuture().sync() + } + } + + private fun createArtemisServerAndClient(): Pair { + val artemisConfig = rigorousMock().also { + doReturn(temporaryFolder.root.toPath() / "artemis").whenever(it).baseDirectory + doReturn(CHARLIE_NAME).whenever(it).myLegalName + doReturn("trustpass").whenever(it).trustStorePassword + doReturn("cordacadevpass").whenever(it).keyStorePassword + doReturn(NetworkHostAndPort("0.0.0.0", artemisPort)).whenever(it).p2pAddress + doReturn(null).whenever(it).jmxMonitoringHttpPort + doReturn(emptyList()).whenever(it).certificateChainCheckPolicies + doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration + } + artemisConfig.configureWithDevSSLCertificate() + + val server = ArtemisMessagingServer(artemisConfig, artemisPort, MAX_MESSAGE_SIZE) + val client = ArtemisMessagingClient(artemisConfig, NetworkHostAndPort("localhost", artemisPort), MAX_MESSAGE_SIZE) + server.start() + client.start() + return Pair(server, client) + } + + private fun createClient(): AMQPClient { + val clientConfig = rigorousMock().also { + doReturn(temporaryFolder.root.toPath() / "client").whenever(it).baseDirectory + doReturn(BOB_NAME).whenever(it).myLegalName + doReturn("trustpass").whenever(it).trustStorePassword + doReturn("cordacadevpass").whenever(it).keyStorePassword + } + clientConfig.configureWithDevSSLCertificate() + + val clientTruststore = clientConfig.loadTrustStore().internal + val clientKeystore = clientConfig.loadSslKeyStore().internal + return AMQPClient( + listOf(NetworkHostAndPort("localhost", serverPort), + NetworkHostAndPort("localhost", serverPort2), + NetworkHostAndPort("localhost", artemisPort)), + setOf(ALICE_NAME, CHARLIE_NAME), + PEER_USER, + PEER_USER, + clientKeystore, + clientConfig.keyStorePassword, + clientTruststore, true, + socksProxyConfig = SocksProxyConfig(SocksProxyVersion.SOCKS5, NetworkHostAndPort("127.0.0.1", socksPort), null, null)) + } + + private fun createSharedThreadsClient(sharedEventGroup: EventLoopGroup, id: Int): AMQPClient { + val clientConfig = rigorousMock().also { + doReturn(temporaryFolder.root.toPath() / "client_%$id").whenever(it).baseDirectory + doReturn(CordaX500Name(null, "client $id", "Corda", "London", null, "GB")).whenever(it).myLegalName + doReturn("trustpass").whenever(it).trustStorePassword + doReturn("cordacadevpass").whenever(it).keyStorePassword + } + clientConfig.configureWithDevSSLCertificate() + + val clientTruststore = clientConfig.loadTrustStore().internal + val clientKeystore = clientConfig.loadSslKeyStore().internal + return AMQPClient( + listOf(NetworkHostAndPort("localhost", serverPort)), + setOf(ALICE_NAME), + PEER_USER, + PEER_USER, + clientKeystore, + clientConfig.keyStorePassword, + clientTruststore, true, sharedEventGroup, + socksProxyConfig = SocksProxyConfig(SocksProxyVersion.SOCKS5, NetworkHostAndPort("127.0.0.1", socksPort), null, null)) + } + + private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME): AMQPServer { + val serverConfig = rigorousMock().also { + doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory + doReturn(name).whenever(it).myLegalName + doReturn("trustpass").whenever(it).trustStorePassword + doReturn("cordacadevpass").whenever(it).keyStorePassword + } + serverConfig.configureWithDevSSLCertificate() + + val serverTruststore = serverConfig.loadTrustStore().internal + val serverKeystore = serverConfig.loadSslKeyStore().internal + return AMQPServer( + "0.0.0.0", + port, + PEER_USER, + PEER_USER, + serverKeystore, + serverConfig.keyStorePassword, + serverTruststore) + } +}