Refactor location of bridge code to allow out of process bridging (#2431)

Fix some issues Andras has seen
This commit is contained in:
Matthew Nesbit
2018-01-30 16:29:59 +00:00
committed by GitHub
parent 2d557d04b4
commit ceff50d656
25 changed files with 75 additions and 68 deletions

View File

@ -0,0 +1,61 @@
package net.corda.nodeapi.internal
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_USER
import net.corda.nodeapi.internal.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ClientSessionFactory
class ArtemisMessagingClient(private val config: SSLConfiguration, private val serverAddress: NetworkHostAndPort, private val maxMessageSize: Int) {
companion object {
private val log = loggerFor<ArtemisMessagingClient>()
}
class Started(val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer)
var started: Started? = null
private set
fun start(): Started = synchronized(this) {
check(started == null) { "start can't be called twice" }
log.info("Connecting to message broker: $serverAddress")
// TODO Add broker CN to config for host verification in case the embedded broker isn't used
val tcpTransport = ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), serverAddress, config)
val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply {
// Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this
// would be the default and the two lines below can be deleted.
connectionTTL = -1
clientFailureCheckPeriod = -1
minLargeMessageSize = maxMessageSize
isUseGlobalPools = nodeSerializationEnv != null
}
val sessionFactory = locator.createSessionFactory()
// Login using the node username. The broker will authenticate us as its node (as opposed to another peer)
// using our TLS certificate.
// Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer
// size of 1MB is acknowledged.
val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE)
session.start()
// Create a general purpose producer.
val producer = session.createProducer()
return Started(sessionFactory, session, producer).also { started = it }
}
fun stop() = synchronized(this) {
started?.run {
producer.close()
// Ensure any trailing messages are committed to the journal
session.commit()
// Closing the factory closes all the sessions it produced as well.
sessionFactory.close()
}
started = null
}
}

View File

@ -0,0 +1,213 @@
package net.corda.nodeapi.internal.bridging
import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.VisibleForTesting
import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_USER
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER
import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateLocalQueueToInboxAddress
import net.corda.nodeapi.internal.bridging.AMQPBridgeManager.AMQPBridge.Companion.getBridgeName
import net.corda.nodeapi.internal.config.NodeSSLConfiguration
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.LoggerFactory
import rx.Subscription
import java.security.KeyStore
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
/**
* The AMQPBridgeManager holds the list of independent AMQPBridge objects that actively ferry messages to remote Artemis
* inboxes.
* The AMQPBridgeManager also provides a single shared connection to Artemis, although each bridge then creates an
* independent Session for message consumption.
* The Netty thread pool used by the AMQPBridges is also shared and managed by the AMQPBridgeManager.
*/
@VisibleForTesting
class AMQPBridgeManager(val config: NodeSSLConfiguration, val p2pAddress: NetworkHostAndPort, val maxMessageSize: Int) : BridgeManager {
private val lock = ReentrantLock()
private val bridgeNameToBridgeMap = mutableMapOf<String, AMQPBridge>()
private var sharedEventLoopGroup: EventLoopGroup? = null
private val keyStore = config.loadSslKeyStore().internal
private val keyStorePrivateKeyPassword: String = config.keyStorePassword
private val trustStore = config.loadTrustStore().internal
private var artemis: ArtemisMessagingClient? = null
companion object {
private const val NUM_BRIDGE_THREADS = 0 // Default sized pool
}
/**
* Each AMQPBridge is an independent consumer of messages from the Artemis local queue per designated endpoint.
* It attempts to deliver these messages via an AMQPClient instance to the remote Artemis inbox.
* To prevent race conditions the Artemis session/consumer is only created when the AMQPClient has a stable AMQP connection.
* The acknowledgement and removal of messages from the local queue only occurs if there successful end-to-end delivery.
* If the delivery fails the session is rolled back to prevent loss of the message. This may cause duplicate delivery,
* however Artemis and the remote Corda instanced will deduplicate these messages.
*/
private class AMQPBridge(private val queueName: String,
private val target: NetworkHostAndPort,
private val legalNames: Set<CordaX500Name>,
keyStore: KeyStore,
keyStorePrivateKeyPassword: String,
trustStore: KeyStore,
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisMessagingClient) {
companion object {
fun getBridgeName(queueName: String, hostAndPort: NetworkHostAndPort): String = "$queueName -> $hostAndPort"
}
private val log = LoggerFactory.getLogger("$bridgeName:${legalNames.first()}")
val amqpClient = AMQPClient(listOf(target), legalNames, PEER_USER, PEER_USER, keyStore, keyStorePrivateKeyPassword, trustStore, sharedThreadPool = sharedEventGroup)
val bridgeName: String get() = getBridgeName(queueName, target)
private val lock = ReentrantLock() // lock to serialise session level access
private var session: ClientSession? = null
private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null
fun start() {
log.info("Create new AMQP bridge")
connectedSubscription = amqpClient.onConnection.subscribe({ x -> onSocketConnected(x.connected) })
amqpClient.start()
}
fun stop() {
log.info("Stopping AMQP bridge")
lock.withLock {
synchronized(artemis) {
consumer?.close()
consumer = null
session?.stop()
session = null
}
}
amqpClient.stop()
connectedSubscription?.unsubscribe()
connectedSubscription = null
}
private fun onSocketConnected(connected: Boolean) {
lock.withLock {
synchronized(artemis) {
if (connected) {
log.info("Bridge Connected")
val sessionFactory = artemis.started!!.sessionFactory
val session = sessionFactory.createSession(NODE_USER, NODE_USER, false, false, false, false, DEFAULT_ACK_BATCH_SIZE)
this.session = session
val consumer = session.createConsumer(queueName)
this.consumer = consumer
consumer.setMessageHandler(this@AMQPBridge::clientArtemisMessageHandler)
session.start()
} else {
log.info("Bridge Disconnected")
consumer?.close()
consumer = null
session?.stop()
session = null
}
}
}
}
private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) {
lock.withLock {
val data = ByteArray(artemisMessage.bodySize).apply { artemisMessage.bodyBuffer.readBytes(this) }
val properties = HashMap<Any?, Any?>()
for (key in artemisMessage.propertyNames) {
var value = artemisMessage.getObjectProperty(key)
if (value is SimpleString) {
value = value.toString()
}
properties[key.toString()] = value
}
log.debug { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty("_AMQ_DUPL_ID")}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(data, peerInbox,
legalNames.first().toString(),
properties)
sendableMessage.onComplete.then {
log.debug { "Bridge ACK ${sendableMessage.onComplete.get()}" }
lock.withLock {
if (sendableMessage.onComplete.get() == MessageStatus.Acknowledged) {
artemisMessage.acknowledge()
session?.commit()
} else {
log.info("Rollback rejected message uuid: ${artemisMessage.getObjectProperty("_AMQ_DUPL_ID")}")
session?.rollback(false)
}
}
}
amqpClient.write(sendableMessage)
}
}
}
private fun gatherAddresses(node: NodeInfo): Sequence<ArtemisMessagingComponent.ArtemisPeerAddress> {
val address = node.addresses.single()
return node.legalIdentitiesAndCerts.map { ArtemisMessagingComponent.NodeAddress(it.party.owningKey, address) }.asSequence()
}
override fun deployBridge(queueName: String, target: NetworkHostAndPort, legalNames: Set<CordaX500Name>) {
if (bridgeExists(getBridgeName(queueName, target))) {
return
}
val newBridge = AMQPBridge(queueName, target, legalNames, keyStore, keyStorePrivateKeyPassword, trustStore, sharedEventLoopGroup!!, artemis!!)
lock.withLock {
bridgeNameToBridgeMap[newBridge.bridgeName] = newBridge
}
newBridge.start()
}
override fun destroyBridges(node: NodeInfo) {
lock.withLock {
gatherAddresses(node).forEach {
val bridge = bridgeNameToBridgeMap.remove(getBridgeName(it.queueName, it.hostAndPort))
bridge?.stop()
}
}
}
override fun destroyBridge(queueName: String, hostAndPort: NetworkHostAndPort) {
lock.withLock {
val bridge = bridgeNameToBridgeMap.remove(getBridgeName(queueName, hostAndPort))
bridge?.stop()
}
}
override fun bridgeExists(bridgeName: String): Boolean = lock.withLock { bridgeNameToBridgeMap.containsKey(bridgeName) }
override fun start() {
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS)
val artemis = ArtemisMessagingClient(config, p2pAddress, maxMessageSize)
this.artemis = artemis
artemis.start()
}
override fun stop() = close()
override fun close() {
lock.withLock {
for (bridge in bridgeNameToBridgeMap.values) {
bridge.stop()
}
sharedEventLoopGroup?.shutdownGracefully()
sharedEventLoopGroup?.terminationFuture()?.sync()
sharedEventLoopGroup = null
bridgeNameToBridgeMap.clear()
artemis?.stop()
}
}
}

View File

@ -0,0 +1,116 @@
package net.corda.nodeapi.internal.bridging
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX
import net.corda.nodeapi.internal.config.NodeSSLConfiguration
import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import java.util.*
class BridgeControlListener(val config: NodeSSLConfiguration,
val p2pAddress: NetworkHostAndPort,
val maxMessageSize: Int) : AutoCloseable {
private val bridgeId: String = UUID.randomUUID().toString()
private val bridgeManager: BridgeManager = AMQPBridgeManager(config, p2pAddress, maxMessageSize)
private val validInboundQueues = mutableSetOf<String>()
private var artemis: ArtemisMessagingClient? = null
private var controlConsumer: ClientConsumer? = null
companion object {
private val log = contextLogger()
}
fun start() {
stop()
bridgeManager.start()
val artemis = ArtemisMessagingClient(config, p2pAddress, maxMessageSize)
this.artemis = artemis
artemis.start()
val artemisClient = artemis.started!!
val artemisSession = artemisClient.session
val bridgeControlQueue = "$BRIDGE_CONTROL.$bridgeId"
artemisSession.createTemporaryQueue(BRIDGE_CONTROL, RoutingType.MULTICAST, bridgeControlQueue)
val control = artemisSession.createConsumer(bridgeControlQueue)
controlConsumer = control
control.setMessageHandler { msg ->
try {
processControlMessage(msg)
} catch (ex: Exception) {
log.error("Unable to process bridge control message", ex)
}
}
val startupMessage = BridgeControl.BridgeToNodeSnapshotRequest(bridgeId).serialize(context = SerializationDefaults.P2P_CONTEXT).bytes
val bridgeRequest = artemisSession.createMessage(false)
bridgeRequest.writeBodyBufferBytes(startupMessage)
artemisClient.producer.send(BRIDGE_NOTIFY, bridgeRequest)
}
fun stop() {
controlConsumer?.close()
controlConsumer = null
artemis?.stop()
artemis = null
bridgeManager.stop()
}
override fun close() = stop()
private fun validateInboxQueueName(queueName: String): Boolean {
return queueName.startsWith(P2P_PREFIX) && artemis!!.started!!.session.queueQuery(SimpleString(queueName)).isExists
}
private fun validateBridgingQueueName(queueName: String): Boolean {
return queueName.startsWith(PEERS_PREFIX) && artemis!!.started!!.session.queueQuery(SimpleString(queueName)).isExists
}
private fun processControlMessage(msg: ClientMessage) {
val data: ByteArray = ByteArray(msg.bodySize).apply { msg.bodyBuffer.readBytes(this) }
val controlMessage = data.deserialize<BridgeControl>(context = SerializationDefaults.P2P_CONTEXT)
log.info("Received bridge control message $controlMessage")
when (controlMessage) {
is BridgeControl.NodeToBridgeSnapshot -> {
if (!controlMessage.inboxQueues.all { validateInboxQueueName(it) }) {
log.error("Invalid queue names in control message $controlMessage")
return
}
if (!controlMessage.sendQueues.all { validateBridgingQueueName(it.queueName) }) {
log.error("Invalid queue names in control message $controlMessage")
return
}
for (outQueue in controlMessage.sendQueues) {
bridgeManager.deployBridge(outQueue.queueName, outQueue.targets.first(), outQueue.legalNames.toSet())
}
// TODO For now we just record the inboxes, but we don't use the information, but eventually out of process bridges will use this for validating inbound messages.
validInboundQueues.addAll(controlMessage.inboxQueues)
}
is BridgeControl.BridgeToNodeSnapshotRequest -> {
log.error("Message from Bridge $controlMessage detected on wrong topic!")
}
is BridgeControl.Create -> {
if (!validateBridgingQueueName((controlMessage.bridgeInfo.queueName))) {
log.error("Invalid queue names in control message $controlMessage")
return
}
bridgeManager.deployBridge(controlMessage.bridgeInfo.queueName, controlMessage.bridgeInfo.targets.first(), controlMessage.bridgeInfo.legalNames.toSet())
}
is BridgeControl.Delete -> {
if (!controlMessage.bridgeInfo.queueName.startsWith(PEERS_PREFIX)) {
log.error("Invalid queue names in control message $controlMessage")
return
}
bridgeManager.destroyBridge(controlMessage.bridgeInfo.queueName, controlMessage.bridgeInfo.targets.first())
}
}
}
}

View File

@ -1,4 +1,4 @@
package net.corda.nodeapi.internal
package net.corda.nodeapi.internal.bridging
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.CordaSerializable

View File

@ -0,0 +1,24 @@
package net.corda.nodeapi.internal.bridging
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.VisibleForTesting
import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort
/**
* Provides an internal interface that the [BridgeControlListener] delegates to for Bridge activities.
*/
@VisibleForTesting
interface BridgeManager : AutoCloseable {
fun deployBridge(queueName: String, target: NetworkHostAndPort, legalNames: Set<CordaX500Name>)
fun destroyBridges(node: NodeInfo)
fun destroyBridge(queueName: String, hostAndPort: NetworkHostAndPort)
fun bridgeExists(bridgeName: String): Boolean
fun start()
fun stop()
}

View File

@ -0,0 +1,484 @@
package net.corda.nodeapi.internal.protonwrapper.engine
import io.netty.buffer.ByteBuf
import io.netty.buffer.PooledByteBufAllocator
import io.netty.buffer.Unpooled
import io.netty.channel.Channel
import io.netty.channel.ChannelHandlerContext
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.messages.impl.ReceivedMessageImpl
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import org.apache.qpid.proton.Proton
import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.amqp.Symbol
import org.apache.qpid.proton.amqp.messaging.*
import org.apache.qpid.proton.amqp.messaging.Properties
import org.apache.qpid.proton.amqp.messaging.Target
import org.apache.qpid.proton.amqp.transaction.Coordinator
import org.apache.qpid.proton.amqp.transport.ErrorCondition
import org.apache.qpid.proton.amqp.transport.ReceiverSettleMode
import org.apache.qpid.proton.amqp.transport.SenderSettleMode
import org.apache.qpid.proton.engine.*
import org.apache.qpid.proton.message.Message
import org.apache.qpid.proton.message.ProtonJMessage
import org.slf4j.LoggerFactory
import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.util.*
/**
* This ConnectionStateMachine class handles the events generated by the proton-j library to track
* various logical connection, transport and link objects and to drive packet processing.
* It is single threaded per physical SSL connection just like the proton-j library,
* but this threading lock is managed by the EventProcessor class that calls this.
* It ultimately posts application packets to/from from the netty transport pipeline.
*/
internal class ConnectionStateMachine(serverMode: Boolean,
collector: Collector,
private val localLegalName: String,
private val remoteLegalName: String,
userName: String?,
password: String?) : BaseHandler() {
companion object {
private const val IDLE_TIMEOUT = 10000
}
val connection: Connection
private val log = LoggerFactory.getLogger(localLegalName)
private val transport: Transport
private val id = UUID.randomUUID().toString()
private var session: Session? = null
private val messageQueues = mutableMapOf<String, LinkedList<SendableMessageImpl>>()
private val unackedQueue = LinkedList<SendableMessageImpl>()
private val receivers = mutableMapOf<String, Receiver>()
private val senders = mutableMapOf<String, Sender>()
private var tagId: Int = 0
init {
connection = Engine.connection()
connection.container = "CORDA:$id"
transport = Engine.transport()
transport.idleTimeout = IDLE_TIMEOUT
transport.context = connection
transport.setEmitFlowEventOnSend(true)
connection.collect(collector)
val sasl = transport.sasl()
if (userName != null) {
//TODO This handshake is required for our queue permission logic in Artemis
sasl.setMechanisms("PLAIN")
if (serverMode) {
sasl.server()
sasl.done(Sasl.PN_SASL_OK)
} else {
sasl.plain(userName, password)
sasl.client()
}
} else {
sasl.setMechanisms("ANONYMOUS")
if (serverMode) {
sasl.server()
sasl.done(Sasl.PN_SASL_OK)
} else {
sasl.client()
}
}
transport.bind(connection)
if (!serverMode) {
connection.open()
}
}
override fun onConnectionInit(event: Event) {
val connection = event.connection
log.debug { "Connection init $connection" }
}
override fun onConnectionLocalOpen(event: Event) {
val connection = event.connection
log.info("Connection local open $connection")
val session = connection.session()
session.open()
this.session = session
for (target in messageQueues.keys) {
getSender(target)
}
}
override fun onConnectionLocalClose(event: Event) {
val connection = event.connection
log.info("Connection local close $connection")
connection.close()
connection.free()
}
override fun onConnectionUnbound(event: Event) {
if (event.connection == this.connection) {
val channel = connection.context as? Channel
if (channel != null) {
if (channel.isActive) {
channel.close()
}
}
}
}
override fun onConnectionFinal(event: Event) {
val connection = event.connection
log.debug { "Connection final $connection" }
if (connection == this.connection) {
this.connection.context = null
for (queue in messageQueues.values) {
// clear any dead messages
while (true) {
val msg = queue.poll()
if (msg != null) {
msg.doComplete(MessageStatus.Rejected)
msg.release()
} else {
break
}
}
}
messageQueues.clear()
while (true) {
val msg = unackedQueue.poll()
if (msg != null) {
msg.doComplete(MessageStatus.Rejected)
msg.release()
} else {
break
}
}
// shouldn't happen, but close socket channel now if not already done
val channel = connection.context as? Channel
if (channel != null && channel.isActive) {
channel.close()
}
// shouldn't happen, but cleanup any stranded items
transport.context = null
session = null
receivers.clear()
senders.clear()
}
}
override fun onTransportHeadClosed(event: Event) {
val transport = event.transport
log.debug { "Transport Head Closed $transport" }
transport.close_tail()
}
override fun onTransportTailClosed(event: Event) {
val transport = event.transport
log.debug { "Transport Tail Closed $transport" }
transport.close_head()
}
override fun onTransportClosed(event: Event) {
val transport = event.transport
log.debug { "Transport Closed $transport" }
if (transport == this.transport) {
transport.unbind()
transport.free()
transport.context = null
}
}
override fun onTransportError(event: Event) {
val transport = event.transport
log.info("Transport Error $transport")
val condition = event.transport.condition
if (condition != null) {
log.info("Error: ${condition.description}")
} else {
log.info("Error (no description returned).")
}
}
override fun onTransport(event: Event) {
val transport = event.transport
log.debug { "Transport $transport" }
onTransportInternal(transport)
}
private fun onTransportInternal(transport: Transport) {
if (!transport.isClosed) {
val pending = transport.pending() // Note this drives frame generation, which the susbsequent writes push to the socket
if (pending > 0) {
val connection = transport.context as? Connection
val channel = connection?.context as? Channel
channel?.writeAndFlush(transport)
}
}
}
override fun onSessionInit(event: Event) {
val session = event.session
log.debug { "Session init $session" }
}
override fun onSessionLocalOpen(event: Event) {
val session = event.session
log.debug { "Session local open $session" }
}
private fun getSender(target: String): Sender {
if (!senders.containsKey(target)) {
val sender = session!!.sender(UUID.randomUUID().toString())
sender.source = Source().apply {
address = target
dynamic = false
durable = TerminusDurability.NONE
}
sender.target = Target().apply {
address = target
dynamic = false
durable = TerminusDurability.UNSETTLED_STATE
}
sender.senderSettleMode = SenderSettleMode.UNSETTLED
sender.receiverSettleMode = ReceiverSettleMode.FIRST
senders[target] = sender
sender.open()
}
return senders[target]!!
}
override fun onSessionLocalClose(event: Event) {
val session = event.session
log.debug { "Session local close $session" }
session.close()
session.free()
}
override fun onSessionFinal(event: Event) {
val session = event.session
log.debug { "Session final $session" }
if (session == this.session) {
this.session = null
}
}
override fun onLinkLocalOpen(event: Event) {
val link = event.link
if (link is Sender) {
log.debug { "Sender Link local open ${link.name} ${link.source} ${link.target}" }
senders[link.target.address] = link
transmitMessages(link)
}
if (link is Receiver) {
log.debug { "Receiver Link local open ${link.name} ${link.source} ${link.target}" }
receivers[link.target.address] = link
}
}
override fun onLinkRemoteOpen(event: Event) {
val link = event.link
if (link is Receiver) {
if (link.remoteTarget is Coordinator) {
log.debug { "Coordinator link received" }
}
}
}
override fun onLinkFinal(event: Event) {
val link = event.link
if (link is Sender) {
log.debug { "Sender Link final ${link.name} ${link.source} ${link.target}" }
senders.remove(link.target.address)
}
if (link is Receiver) {
log.debug { "Receiver Link final ${link.name} ${link.source} ${link.target}" }
receivers.remove(link.target.address)
}
}
override fun onLinkFlow(event: Event) {
val link = event.link
if (link is Sender) {
log.debug { "Sender Flow event: ${link.name} ${link.source} ${link.target}" }
if (senders.containsKey(link.target.address)) {
transmitMessages(link)
}
} else if (link is Receiver) {
log.debug { "Receiver Flow event: ${link.name} ${link.source} ${link.target}" }
}
}
fun processTransport() {
onTransportInternal(transport)
}
private fun transmitMessages(sender: Sender) {
val messageQueue = messageQueues.getOrPut(sender.target.address, { LinkedList() })
while (sender.credit > 0) {
log.debug { "Sender credit: ${sender.credit}" }
val nextMessage = messageQueue.poll()
if (nextMessage != null) {
try {
val messageBuf = nextMessage.buf!!
val buf = ByteBuffer.allocate(4)
buf.putInt(tagId++)
val delivery = sender.delivery(buf.array())
delivery.context = nextMessage
sender.send(messageBuf.array(), messageBuf.arrayOffset() + messageBuf.readerIndex(), messageBuf.readableBytes())
nextMessage.status = MessageStatus.Sent
log.debug { "Put tag ${javax.xml.bind.DatatypeConverter.printHexBinary(delivery.tag)} on wire uuid: ${nextMessage.applicationProperties["_AMQ_DUPL_ID"]}" }
unackedQueue.offer(nextMessage)
sender.advance()
} finally {
nextMessage.release()
}
} else {
break
}
}
}
override fun onDelivery(event: Event) {
val delivery = event.delivery
log.debug { "Delivery $delivery" }
val link = delivery.link
if (link is Receiver) {
if (delivery.isReadable && !delivery.isPartial) {
val pending = delivery.pending()
val amqpMessage = decodeAMQPMessage(pending, link)
val payload = (amqpMessage.body as Data).value.array
val connection = event.connection
val channel = connection?.context as? Channel
if (channel != null) {
val appProperties = HashMap(amqpMessage.applicationProperties.value)
appProperties["_AMQ_VALIDATED_USER"] = remoteLegalName
val localAddress = channel.localAddress() as InetSocketAddress
val remoteAddress = channel.remoteAddress() as InetSocketAddress
val receivedMessage = ReceivedMessageImpl(
payload,
link.source.address,
remoteLegalName,
NetworkHostAndPort(localAddress.hostString, localAddress.port),
localLegalName,
NetworkHostAndPort(remoteAddress.hostString, remoteAddress.port),
appProperties,
channel,
delivery)
log.debug { "Full message received uuid: ${appProperties["_AMQ_DUPL_ID"]}" }
channel.writeAndFlush(receivedMessage)
if (link.current() == delivery) {
link.advance()
}
} else {
delivery.disposition(Rejected())
delivery.settle()
}
}
} else if (link is Sender) {
log.debug { "Sender delivery confirmed tag ${javax.xml.bind.DatatypeConverter.printHexBinary(delivery.tag)}" }
val ok = delivery.remotelySettled() && delivery.remoteState == Accepted.getInstance()
val sourceMessage = delivery.context as? SendableMessageImpl
unackedQueue.remove(sourceMessage)
sourceMessage?.doComplete(if (ok) MessageStatus.Acknowledged else MessageStatus.Rejected)
delivery.settle()
}
}
private fun encodeAMQPMessage(message: ProtonJMessage): ByteBuf {
val buffer = PooledByteBufAllocator.DEFAULT.heapBuffer(1500)
try {
try {
message.encode(NettyWritable(buffer))
val bytes = ByteArray(buffer.writerIndex())
buffer.readBytes(bytes)
return Unpooled.wrappedBuffer(bytes)
} catch (ex: Exception) {
log.error("Unable to encode message as AMQP packet", ex)
throw ex
}
} finally {
buffer.release()
}
}
private fun encodePayloadBytes(msg: SendableMessageImpl): ByteBuf {
val message = Proton.message() as ProtonJMessage
message.body = Data(Binary(msg.payload))
message.isDurable = true
message.properties = Properties()
val appProperties = HashMap(msg.applicationProperties)
//TODO We shouldn't have to do this, but Artemis Server doesn't set the header on AMQP packets.
// Fortunately, when we are bridge to bridge/bridge to float we can authenticate links there.
appProperties["_AMQ_VALIDATED_USER"] = localLegalName
message.applicationProperties = ApplicationProperties(appProperties)
return encodeAMQPMessage(message)
}
private fun decodeAMQPMessage(pending: Int, link: Receiver): Message {
val msgBuf = PooledByteBufAllocator.DEFAULT.heapBuffer(pending)
try {
link.recv(NettyWritable(msgBuf))
val amqpMessage = Proton.message()
amqpMessage.decode(msgBuf.array(), msgBuf.arrayOffset() + msgBuf.readerIndex(), msgBuf.readableBytes())
return amqpMessage
} finally {
msgBuf.release()
}
}
fun transportWriteMessage(msg: SendableMessageImpl) {
log.debug { "Queue application message write uuid: ${msg.applicationProperties["_AMQ_DUPL_ID"]} ${javax.xml.bind.DatatypeConverter.printHexBinary(msg.payload)}" }
msg.buf = encodePayloadBytes(msg)
val messageQueue = messageQueues.getOrPut(msg.topic, { LinkedList() })
messageQueue.offer(msg)
if (session != null) {
val sender = getSender(msg.topic)
transmitMessages(sender)
}
}
fun transportProcessInput(msg: ByteBuf) {
val source = msg.nioBuffer()
try {
do {
val buffer = transport.inputBuffer
val limit = Math.min(buffer.remaining(), source.remaining())
val duplicate = source.duplicate()
duplicate.limit(source.position() + limit)
buffer.put(duplicate)
transport.processInput().checkIsOk()
source.position(source.position() + limit)
} while (source.hasRemaining())
} catch (ex: Exception) {
val condition = ErrorCondition()
condition.condition = Symbol.getSymbol("proton:io")
condition.description = ex.message
transport.condition = condition
transport.close_tail()
transport.pop(Math.max(0, transport.pending())) // Force generation of TRANSPORT_HEAD_CLOSE (not in C code)
}
}
fun transportProcessOutput(ctx: ChannelHandlerContext) {
try {
var done = false
while (!done) {
val toWrite = transport.outputBuffer
if (toWrite != null && toWrite.hasRemaining()) {
val outbound = ctx.alloc().buffer(toWrite.remaining())
outbound.writeBytes(toWrite)
ctx.write(outbound)
transport.outputConsumed()
} else {
done = true
}
}
ctx.flush()
} catch (ex: Exception) {
val condition = ErrorCondition()
condition.condition = Symbol.getSymbol("proton:io")
condition.description = ex.message
transport.condition = condition
transport.close_head()
transport.pop(Math.max(0, transport.pending())) // Force generation of TRANSPORT_HEAD_CLOSE (not in C code)
}
}
}

View File

@ -0,0 +1,136 @@
package net.corda.nodeapi.internal.protonwrapper.engine
import io.netty.buffer.ByteBuf
import io.netty.channel.Channel
import io.netty.channel.ChannelHandlerContext
import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.messages.impl.ReceivedMessageImpl
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import org.apache.qpid.proton.Proton
import org.apache.qpid.proton.amqp.messaging.Accepted
import org.apache.qpid.proton.amqp.messaging.Rejected
import org.apache.qpid.proton.amqp.transport.DeliveryState
import org.apache.qpid.proton.amqp.transport.ErrorCondition
import org.apache.qpid.proton.engine.*
import org.apache.qpid.proton.engine.impl.CollectorImpl
import org.apache.qpid.proton.reactor.FlowController
import org.apache.qpid.proton.reactor.Handshaker
import org.slf4j.LoggerFactory
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
/**
* The EventProcessor class converts calls on the netty scheduler/pipeline
* into proton-j engine event calls into the ConnectionStateMachine.
* It also registers a couple of standard event processors for the basic connection handshake
* and simple sliding window flow control, so that these events don't have to live inside ConnectionStateMachine.
* Everything here is single threaded, because the proton-j library has to be run that way.
*/
internal class EventProcessor(channel: Channel,
serverMode: Boolean,
localLegalName: String,
remoteLegalName: String,
userName: String?,
password: String?) : BaseHandler() {
companion object {
private const val FLOW_WINDOW_SIZE = 10
}
private val log = LoggerFactory.getLogger(localLegalName)
private val lock = ReentrantLock()
private var pendingExecute: Boolean = false
private val executor: ScheduledExecutorService = channel.eventLoop()
private val collector = Proton.collector() as CollectorImpl
private val handlers = mutableListOf<Handler>()
private val stateMachine: ConnectionStateMachine = ConnectionStateMachine(serverMode,
collector,
localLegalName,
remoteLegalName,
userName,
password)
val connection: Connection = stateMachine.connection
init {
addHandler(Handshaker())
addHandler(FlowController(FLOW_WINDOW_SIZE))
addHandler(stateMachine)
connection.context = channel
tick(stateMachine.connection)
}
fun addHandler(handler: Handler) = handlers.add(handler)
private fun popEvent(): Event? {
var ev = collector.peek()
if (ev != null) {
ev = ev.copy() // prevent mutation by collector.pop()
collector.pop()
}
return ev
}
private fun tick(connection: Connection) {
lock.withLock {
try {
if ((connection.localState != EndpointState.CLOSED) && !connection.transport.isClosed) {
val now = System.currentTimeMillis()
val tickDelay = Math.max(0L, connection.transport.tick(now) - now)
executor.schedule({ tick(connection) }, tickDelay, TimeUnit.MILLISECONDS)
}
} catch (ex: Exception) {
connection.transport.close()
connection.condition = ErrorCondition()
}
}
}
fun processEvents() {
lock.withLock {
pendingExecute = false
log.debug { "Process Events" }
while (true) {
val ev = popEvent() ?: break
log.debug { "Process event: $ev" }
for (handler in handlers) {
handler.handle(ev)
}
}
stateMachine.processTransport()
log.debug { "Process Events Done" }
}
}
fun processEventsAsync() {
lock.withLock {
if (!pendingExecute) {
pendingExecute = true
executor.execute { processEvents() }
}
}
}
fun close() {
if (connection.localState != EndpointState.CLOSED) {
connection.close()
processEvents()
connection.free()
processEvents()
}
}
fun transportProcessInput(msg: ByteBuf) = lock.withLock { stateMachine.transportProcessInput(msg) }
fun transportProcessOutput(ctx: ChannelHandlerContext) = lock.withLock { stateMachine.transportProcessOutput(ctx) }
fun transportWriteMessage(msg: SendableMessageImpl) = lock.withLock { stateMachine.transportWriteMessage(msg) }
fun complete(completer: ReceivedMessageImpl.MessageCompleter) = lock.withLock {
val status: DeliveryState = if (completer.status == MessageStatus.Acknowledged) Accepted.getInstance() else Rejected()
completer.delivery.disposition(status)
completer.delivery.settle()
}
}

View File

@ -0,0 +1,63 @@
package net.corda.nodeapi.internal.protonwrapper.engine
import io.netty.buffer.ByteBuf
import org.apache.qpid.proton.codec.WritableBuffer
import java.nio.ByteBuffer
/**
* NettyWritable is a utility class allow proton-j encoders to write directly into a
* netty ByteBuf, without any need to materialize a ByteArray copy.
*/
internal class NettyWritable(val nettyBuffer: ByteBuf) : WritableBuffer {
override fun put(b: Byte) {
nettyBuffer.writeByte(b.toInt())
}
override fun putFloat(f: Float) {
nettyBuffer.writeFloat(f)
}
override fun putDouble(d: Double) {
nettyBuffer.writeDouble(d)
}
override fun put(src: ByteArray, offset: Int, length: Int) {
nettyBuffer.writeBytes(src, offset, length)
}
override fun putShort(s: Short) {
nettyBuffer.writeShort(s.toInt())
}
override fun putInt(i: Int) {
nettyBuffer.writeInt(i)
}
override fun putLong(l: Long) {
nettyBuffer.writeLong(l)
}
override fun hasRemaining(): Boolean {
return nettyBuffer.writerIndex() < nettyBuffer.capacity()
}
override fun remaining(): Int {
return nettyBuffer.capacity() - nettyBuffer.writerIndex()
}
override fun position(): Int {
return nettyBuffer.writerIndex()
}
override fun position(position: Int) {
nettyBuffer.writerIndex(position)
}
override fun put(payload: ByteBuffer) {
nettyBuffer.writeBytes(payload)
}
override fun limit(): Int {
return nettyBuffer.capacity()
}
}

View File

@ -0,0 +1,14 @@
package net.corda.nodeapi.internal.protonwrapper.messages
import net.corda.core.utilities.NetworkHostAndPort
/**
* Represents a common interface for both sendable and received application messages.
*/
interface ApplicationMessage {
val payload: ByteArray
val topic: String
val destinationLegalName: String
val destinationLink: NetworkHostAndPort
val applicationProperties: Map<Any?, Any?>
}

View File

@ -0,0 +1,11 @@
package net.corda.nodeapi.internal.protonwrapper.messages
/**
* The processing state of a message.
*/
enum class MessageStatus {
Unsent,
Sent,
Acknowledged,
Rejected
}

View File

@ -0,0 +1,13 @@
package net.corda.nodeapi.internal.protonwrapper.messages
import net.corda.core.utilities.NetworkHostAndPort
/**
* An extension of ApplicationMessage that includes origin information.
*/
interface ReceivedMessage : ApplicationMessage {
val sourceLegalName: String
val sourceLink: NetworkHostAndPort
fun complete(accepted: Boolean)
}

View File

@ -0,0 +1,10 @@
package net.corda.nodeapi.internal.protonwrapper.messages
import net.corda.core.concurrent.CordaFuture
/**
* An extension of ApplicationMessage to allow completion signalling.
*/
interface SendableMessage : ApplicationMessage {
val onComplete: CordaFuture<MessageStatus>
}

View File

@ -0,0 +1,30 @@
package net.corda.nodeapi.internal.protonwrapper.messages.impl
import io.netty.channel.Channel
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import org.apache.qpid.proton.engine.Delivery
/**
* An internal packet management class that allows tracking of asynchronous acknowledgements
* that in turn send Delivery messages back to the originator.
*/
internal class ReceivedMessageImpl(override val payload: ByteArray,
override val topic: String,
override val sourceLegalName: String,
override val sourceLink: NetworkHostAndPort,
override val destinationLegalName: String,
override val destinationLink: NetworkHostAndPort,
override val applicationProperties: Map<Any?, Any?>,
private val channel: Channel,
private val delivery: Delivery) : ReceivedMessage {
data class MessageCompleter(val status: MessageStatus, val delivery: Delivery)
override fun complete(accepted: Boolean) {
val status = if (accepted) MessageStatus.Acknowledged else MessageStatus.Rejected
channel.writeAndFlush(MessageCompleter(status, delivery))
}
override fun toString(): String = "Received ${String(payload)} $topic"
}

View File

@ -0,0 +1,37 @@
package net.corda.nodeapi.internal.protonwrapper.messages.impl
import io.netty.buffer.ByteBuf
import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
/**
* An internal packet management class that allows handling of the encoded buffers and
* allows registration of an acknowledgement handler when the remote receiver confirms durable storage.
*/
internal class SendableMessageImpl(override val payload: ByteArray,
override val topic: String,
override val destinationLegalName: String,
override val destinationLink: NetworkHostAndPort,
override val applicationProperties: Map<Any?, Any?>) : SendableMessage {
var buf: ByteBuf? = null
@Volatile
var status: MessageStatus = MessageStatus.Unsent
private val _onComplete = openFuture<MessageStatus>()
override val onComplete: CordaFuture<MessageStatus> get() = _onComplete
fun release() {
buf?.release()
buf = null
}
fun doComplete(status: MessageStatus) {
this.status = status
_onComplete.set(status)
}
override fun toString(): String = "Sendable ${String(payload)} $topic $status"
}

View File

@ -0,0 +1,156 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.channel.socket.SocketChannel
import io.netty.handler.ssl.SslHandler
import io.netty.handler.ssl.SslHandshakeCompletionEvent
import io.netty.util.ReferenceCountUtil
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.protonwrapper.engine.EventProcessor
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.ReceivedMessageImpl
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import org.apache.qpid.proton.engine.ProtonJTransport
import org.apache.qpid.proton.engine.Transport
import org.apache.qpid.proton.engine.impl.ProtocolTracer
import org.apache.qpid.proton.framing.TransportFrame
import org.slf4j.LoggerFactory
import java.net.InetSocketAddress
import java.security.cert.X509Certificate
/**
* An instance of AMQPChannelHandler sits inside the netty pipeline and controls the socket level lifecycle.
* It also add some extra checks to the SSL handshake to support our non-standard certificate checks of legal identity.
* When a valid SSL connections is made then it initialises a proton-j engine instance to handle the protocol layer.
*/
internal class AMQPChannelHandler(private val serverMode: Boolean,
private val allowedRemoteLegalNames: Set<CordaX500Name>?,
private val userName: String?,
private val password: String?,
private val trace: Boolean,
private val onOpen: (Pair<SocketChannel, ConnectionChange>) -> Unit,
private val onClose: (Pair<SocketChannel, ConnectionChange>) -> Unit,
private val onReceive: (ReceivedMessage) -> Unit) : ChannelDuplexHandler() {
private val log = LoggerFactory.getLogger(allowedRemoteLegalNames?.firstOrNull()?.toString() ?: "AMQPChannelHandler")
private lateinit var remoteAddress: InetSocketAddress
private lateinit var localCert: X509Certificate
private lateinit var remoteCert: X509Certificate
private var eventProcessor: EventProcessor? = null
override fun channelActive(ctx: ChannelHandlerContext) {
val ch = ctx.channel()
remoteAddress = ch.remoteAddress() as InetSocketAddress
val localAddress = ch.localAddress() as InetSocketAddress
log.info("New client connection ${ch.id()} from $remoteAddress to $localAddress")
}
private fun createAMQPEngine(ctx: ChannelHandlerContext) {
val ch = ctx.channel()
eventProcessor = EventProcessor(ch, serverMode, localCert.subjectX500Principal.toString(), remoteCert.subjectX500Principal.toString(), userName, password)
val connection = eventProcessor!!.connection
val transport = connection.transport as ProtonJTransport
if (trace) {
transport.protocolTracer = object : ProtocolTracer {
override fun sentFrame(transportFrame: TransportFrame) {
log.info("${transportFrame.body}")
}
override fun receivedFrame(transportFrame: TransportFrame) {
log.info("${transportFrame.body}")
}
}
}
ctx.fireChannelActive()
eventProcessor!!.processEventsAsync()
}
override fun channelInactive(ctx: ChannelHandlerContext) {
val ch = ctx.channel()
log.info("Closed client connection ${ch.id()} from $remoteAddress to ${ch.localAddress()}")
onClose(Pair(ch as SocketChannel, ConnectionChange(remoteAddress, null, false)))
eventProcessor?.close()
ctx.fireChannelInactive()
}
override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (evt is SslHandshakeCompletionEvent) {
if (evt.isSuccess) {
val sslHandler = ctx.pipeline().get(SslHandler::class.java)
localCert = sslHandler.engine().session.localCertificates[0].x509
remoteCert = sslHandler.engine().session.peerCertificates[0].x509
try {
val remoteX500Name = CordaX500Name.build(remoteCert.subjectX500Principal)
require(allowedRemoteLegalNames == null || remoteX500Name in allowedRemoteLegalNames)
log.info("handshake completed subject: $remoteX500Name")
} catch (ex: IllegalArgumentException) {
log.error("Invalid certificate subject", ex)
ctx.close()
return
}
createAMQPEngine(ctx)
onOpen(Pair(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, true)))
} else {
log.error("Handshake failure $evt")
ctx.close()
}
}
}
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
try {
log.debug { "Received $msg" }
if (msg is ByteBuf) {
eventProcessor!!.transportProcessInput(msg)
}
} finally {
ReferenceCountUtil.release(msg)
}
eventProcessor!!.processEventsAsync()
}
override fun write(ctx: ChannelHandlerContext, msg: Any, promise: ChannelPromise) {
try {
try {
log.debug { "Sent $msg" }
when (msg) {
// Transfers application packet into the AMQP engine.
is SendableMessageImpl -> {
val inetAddress = InetSocketAddress(msg.destinationLink.host, msg.destinationLink.port)
require(inetAddress == remoteAddress) {
"Message for incorrect endpoint"
}
require(CordaX500Name.parse(msg.destinationLegalName) == CordaX500Name.build(remoteCert.subjectX500Principal)) {
"Message for incorrect legal identity"
}
log.debug { "channel write ${msg.applicationProperties["_AMQ_DUPL_ID"]}" }
eventProcessor!!.transportWriteMessage(msg)
}
// A received AMQP packet has been completed and this self-posted packet will be signalled out to the
// external application.
is ReceivedMessage -> {
onReceive(msg)
}
// A general self-posted event that triggers creation of AMQP frames when required.
is Transport -> {
eventProcessor!!.transportProcessOutput(ctx)
}
// A self-posted event that forwards status updates for delivered packets to the application.
is ReceivedMessageImpl.MessageCompleter -> {
eventProcessor!!.complete(msg)
}
}
} catch (ex: Exception) {
log.error("Error in AMQP write processing", ex)
throw ex
}
} finally {
ReferenceCountUtil.release(msg)
}
eventProcessor!!.processEventsAsync()
}
}

View File

@ -0,0 +1,194 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.bootstrap.Bootstrap
import io.netty.channel.*
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler
import io.netty.util.internal.logging.InternalLoggerFactory
import io.netty.util.internal.logging.Slf4JLoggerFactory
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import rx.Observable
import rx.subjects.PublishSubject
import java.security.KeyStore
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
/**
* The AMQPClient creates a connection initiator that will try to connect in a round-robin fashion
* to the first open SSL socket. It will keep retrying until it is stopped.
* To allow thread resource control it can accept a shared thread pool as constructor input,
* otherwise it creates a self-contained Netty thraed pool and socket objects.
* Once connected it can accept application packets to send via the AMQP protocol.
*/
class AMQPClient(val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val userName: String?,
private val password: String?,
private val keyStore: KeyStore,
private val keyStorePrivateKeyPassword: String,
private val trustStore: KeyStore,
private val trace: Boolean = false,
private val sharedThreadPool: EventLoopGroup? = null) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
val log = contextLogger()
const val RETRY_INTERVAL = 1000L
const val NUM_CLIENT_THREADS = 2
}
private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var workerGroup: EventLoopGroup? = null
@Volatile
private var clientChannel: Channel? = null
// Offset into the list of targets, so that we can implement round-robin reconnect logic.
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private val connectListener = object : ChannelFutureListener {
override fun operationComplete(future: ChannelFuture) {
if (!future.isSuccess) {
log.info("Failed to connect to $currentTarget")
if (!stopping) {
workerGroup?.schedule({
log.info("Retry connect to $currentTarget")
targetIndex = (targetIndex + 1).rem(targets.size)
restart()
}, RETRY_INTERVAL, TimeUnit.MILLISECONDS)
}
} else {
log.info("Connected to $currentTarget")
// Connection established successfully
clientChannel = future.channel()
clientChannel?.closeFuture()?.addListener(closeListener)
}
}
}
private val closeListener = object : ChannelFutureListener {
override fun operationComplete(future: ChannelFuture) {
log.info("Disconnected from $currentTarget")
future.channel()?.disconnect()
clientChannel = null
if (!stopping) {
workerGroup?.schedule({
log.info("Retry connect")
targetIndex = (targetIndex + 1).rem(targets.size)
restart()
}, RETRY_INTERVAL, TimeUnit.MILLISECONDS)
}
}
}
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
init {
keyManagerFactory.init(parent.keyStore, parent.keyStorePrivateKeyPassword.toCharArray())
trustManagerFactory.init(parent.trustStore)
}
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
val handler = createClientSslHelper(parent.currentTarget, keyManagerFactory, trustManagerFactory)
pipeline.addLast("sslHandler", handler)
if (parent.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
pipeline.addLast(AMQPChannelHandler(false,
parent.allowedRemoteLegalNames,
parent.userName,
parent.password,
parent.trace,
{ parent._onConnection.onNext(it.second) },
{ parent._onConnection.onNext(it.second) },
{ rcv -> parent._onReceive.onNext(rcv) }))
}
}
fun start() {
lock.withLock {
log.info("connect to: $currentTarget")
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS)
restart()
}
}
private fun restart() {
val bootstrap = Bootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
bootstrap.group(workerGroup).
channel(NioSocketChannel::class.java).
handler(ClientChannelInitializer(this))
currentTarget = targets[targetIndex]
val clientFuture = bootstrap.connect(currentTarget.host, currentTarget.port)
clientFuture.addListener(connectListener)
}
fun stop() {
lock.withLock {
log.info("disconnect from: $currentTarget")
stopping = true
try {
if (sharedThreadPool == null) {
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
} else {
clientChannel?.close()?.sync()
}
clientChannel = null
workerGroup = null
} finally {
stopping = false
}
log.info("stopped connection to $currentTarget")
}
}
override fun close() = stop()
val connected: Boolean
get() {
val channel = lock.withLock { clientChannel }
return channel?.isActive ?: false
}
fun createMessage(payload: ByteArray,
topic: String,
destinationLegalName: String,
properties: Map<Any?, Any?>): SendableMessage {
return SendableMessageImpl(payload, topic, destinationLegalName, currentTarget, properties)
}
fun write(msg: SendableMessage) {
val channel = clientChannel
if (channel == null) {
throw IllegalStateException("Connection to $targets not active")
} else {
channel.writeAndFlush(msg)
}
}
private val _onReceive = PublishSubject.create<ReceivedMessage>().toSerialized()
val onReceive: Observable<ReceivedMessage>
get() = _onReceive
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
}

View File

@ -0,0 +1,187 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler
import io.netty.util.internal.logging.InternalLoggerFactory
import io.netty.util.internal.logging.Slf4JLoggerFactory
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import org.apache.qpid.proton.engine.Delivery
import rx.Observable
import rx.subjects.PublishSubject
import java.net.BindException
import java.net.InetSocketAddress
import java.security.KeyStore
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
/**
* This create a socket acceptor instance that can receive possibly multiple AMQP connections.
* As of now this is not used outside of testing, but in future it will be used for standalone bridging components.
*/
class AMQPServer(val hostName: String,
val port: Int,
private val userName: String?,
private val password: String?,
private val keyStore: KeyStore,
private val keyStorePrivateKeyPassword: String,
private val trustStore: KeyStore,
private val trace: Boolean = false) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
private val log = contextLogger()
const val NUM_SERVER_THREADS = 4
}
private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
init {
}
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
init {
keyManagerFactory.init(parent.keyStore, parent.keyStorePrivateKeyPassword.toCharArray())
trustManagerFactory.init(parent.trustStore)
}
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
val handler = createServerSslHelper(keyManagerFactory, trustManagerFactory)
pipeline.addLast("sslHandler", handler)
if (parent.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
pipeline.addLast(AMQPChannelHandler(true,
null,
parent.userName,
parent.password,
parent.trace,
{
parent.clientChannels.put(it.first.remoteAddress(), it.first)
parent._onConnection.onNext(it.second)
},
{
parent.clientChannels.remove(it.first.remoteAddress())
parent._onConnection.onNext(it.second)
},
{ rcv -> parent._onReceive.onNext(rcv) }))
}
}
fun start() {
lock.withLock {
stop()
bossGroup = NioEventLoopGroup(1)
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS)
val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
server.group(bossGroup, workerGroup).
channel(NioServerSocketChannel::class.java).
option(ChannelOption.SO_BACKLOG, 100).
handler(LoggingHandler(LogLevel.INFO)).
childHandler(ServerChannelInitializer(this))
log.info("Try to bind $port")
val channelFuture = server.bind(hostName, port).sync() // block/throw here as better to know we failed to claim port than carry on
if (!channelFuture.isDone || !channelFuture.isSuccess) {
throw BindException("Failed to bind port $port")
}
log.info("Listening on port $port")
serverChannel = channelFuture.channel()
}
}
fun stop() {
lock.withLock {
try {
stopping = true
serverChannel?.apply { close() }
serverChannel = null
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup = null
} finally {
stopping = false
}
}
}
override fun close() = stop()
val listening: Boolean
get() {
val channel = lock.withLock { serverChannel }
return channel?.isActive ?: false
}
fun createMessage(payload: ByteArray,
topic: String,
destinationLegalName: String,
destinationLink: NetworkHostAndPort,
properties: Map<Any?, Any?>): SendableMessage {
val dest = InetSocketAddress(destinationLink.host, destinationLink.port)
require(dest in clientChannels.keys) {
"Destination not available"
}
return SendableMessageImpl(payload, topic, destinationLegalName, destinationLink, properties)
}
fun write(msg: SendableMessage) {
val dest = InetSocketAddress(msg.destinationLink.host, msg.destinationLink.port)
val channel = clientChannels[dest]
if (channel == null) {
throw IllegalStateException("Connection to ${msg.destinationLink} not active")
} else {
channel.writeAndFlush(msg)
}
}
fun complete(delivery: Delivery, target: InetSocketAddress) {
val channel = clientChannels[target]
channel?.apply {
writeAndFlush(delivery)
}
}
private val _onReceive = PublishSubject.create<ReceivedMessage>().toSerialized()
val onReceive: Observable<ReceivedMessage>
get() = _onReceive
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
}

View File

@ -0,0 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import java.net.InetSocketAddress
import java.security.cert.X509Certificate
data class ConnectionChange(val remoteAddress: InetSocketAddress, val remoteCert: X509Certificate?, val connected: Boolean)

View File

@ -0,0 +1,39 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.handler.ssl.SslHandler
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.ArtemisTcpTransport
import java.security.SecureRandom
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
internal fun createClientSslHelper(target: NetworkHostAndPort,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers
sslContext.init(keyManagers, trustManagers, SecureRandom())
val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray()
sslEngine.enableSessionCreation = true
return SslHandler(sslEngine)
}
internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers
sslContext.init(keyManagers, trustManagers, SecureRandom())
val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false
sslEngine.needClientAuth = true
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray()
sslEngine.enableSessionCreation = true
return SslHandler(sslEngine)
}