ENT-9806: Using Artemis SSLContextFactory service to pass in custom TrustManagerFactory. This removes the need to copy code from NettyAcceptor.

This commit is contained in:
Shams Asari 2023-06-01 17:33:04 +01:00
parent 5706f89639
commit 4dcd9245d3
8 changed files with 80 additions and 132 deletions

View File

@ -183,10 +183,7 @@ class ArtemisTcpTransport {
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) { if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use // NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, which we pass directly into NodeNettyAcceptorFactory. // more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory.
//
// This, however, requires copying a lot of code from NettyAcceptor into NodeNettyAcceptor. The version of Artemis in
// Corda 4.9 solves this problem by introducing a "trustManagerFactoryPlugin" config option.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
} }
return createTransport( return createTransport(
@ -208,6 +205,10 @@ class ArtemisTcpTransport {
threadPoolName: String, threadPoolName: String,
trace: Boolean, trace: Boolean,
remotingThreads: Int?): TransportConfiguration { remotingThreads: Int?): TransportConfiguration {
if (enableSSL) {
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
}
return createTransport( return createTransport(
NodeNettyConnectorFactory::class.java.name, NodeNettyConnectorFactory::class.java.name,
hostAndPort, hostAndPort,
@ -232,8 +233,6 @@ class ArtemisTcpTransport {
if (enableSSL) { if (enableSSL) {
options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",") options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",") options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
} }
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357) // By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1 options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1

View File

@ -497,7 +497,9 @@ class ArtemisServerRevocationTest : AbstractServerRevocationTest() {
} }
val queueName = "${P2P_PREFIX}Test" val queueName = "${P2P_PREFIX}Test"
artemisNode.client.started!!.session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) artemisNode.client.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)

View File

@ -140,7 +140,7 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() {
// This check is redundant as it was performed already during the SSL handshake // This check is redundant as it was performed already during the SSL handshake
CertificateChainCheckPolicy.RootMustMatch CertificateChainCheckPolicy.RootMustMatch
.createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore) .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore)
.checkCertificateChain(certificates!!) .checkCertificateChain(certificates)
Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE)))
} }
else -> { else -> {

View File

@ -42,7 +42,6 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import java.io.IOException
import java.lang.Long.max import java.lang.Long.max
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.login.AppConfigurationEntry import javax.security.auth.login.AppConfigurationEntry

View File

@ -5,24 +5,14 @@ import io.netty.channel.ChannelHandlerContext
import io.netty.channel.group.ChannelGroup import io.netty.channel.group.ChannelGroup
import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler import io.netty.handler.logging.LoggingHandler
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslHandler import io.netty.handler.ssl.SslHandler
import io.netty.handler.ssl.SslHandshakeTimeoutException import io.netty.handler.ssl.SslHandshakeTimeoutException
import io.netty.handler.ssl.SslProvider
import net.corda.core.internal.declaredField import net.corda.core.internal.declaredField
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration
import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.api.core.BaseInterceptor
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import org.apache.activemq.artemis.core.remoting.impl.ssl.SSLSupport
import org.apache.activemq.artemis.core.server.ActiveMQServerLogger
import org.apache.activemq.artemis.core.server.balancing.RedirectHandler import org.apache.activemq.artemis.core.server.balancing.RedirectHandler
import org.apache.activemq.artemis.core.server.cluster.ClusterConnection import org.apache.activemq.artemis.core.server.cluster.ClusterConnection
import org.apache.activemq.artemis.spi.core.protocol.ProtocolManager import org.apache.activemq.artemis.spi.core.protocol.ProtocolManager
@ -30,24 +20,20 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor
import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory
import org.apache.activemq.artemis.spi.core.remoting.BufferHandler import org.apache.activemq.artemis.spi.core.remoting.BufferHandler
import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener
import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider
import org.apache.activemq.artemis.utils.ConfigurationHelper import org.apache.activemq.artemis.utils.ConfigurationHelper
import org.apache.activemq.artemis.utils.actors.OrderedExecutor import org.apache.activemq.artemis.utils.actors.OrderedExecutor
import java.net.SocketAddress import java.net.SocketAddress
import java.nio.channels.ClosedChannelException import java.nio.channels.ClosedChannelException
import java.nio.file.Paths
import java.security.PrivilegedExceptionAction
import java.time.Duration import java.time.Duration
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.regex.Pattern import java.util.regex.Pattern
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLEngine import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException import javax.net.ssl.SSLPeerUnverifiedException
import javax.net.ssl.TrustManagerFactory
import javax.security.auth.Subject
@Suppress("unused", "TooGenericExceptionCaught", "ComplexMethod", "MagicNumber", "TooManyFunctions") @Suppress("unused") // Used via reflection in ArtemisTcpTransport
class NodeNettyAcceptorFactory : AcceptorFactory { class NodeNettyAcceptorFactory : AcceptorFactory {
override fun createAcceptor(name: String?, override fun createAcceptor(name: String?,
clusterConnection: ClusterConnection?, clusterConnection: ClusterConnection?,
@ -74,6 +60,12 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
{ {
companion object { companion object {
private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""")
init {
// Make sure Artemis isn't using another (Open)SSLContextFactory
check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory)
check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory)
}
} }
private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration)
@ -100,7 +92,7 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
@Synchronized @Synchronized
override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler { override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler {
applyThreadPoolName() applyThreadPoolName()
val engine = getSSLEngine(alloc, peerHost, peerPort) val engine = super.getSslHandler(alloc, peerHost, peerPort).engine()
val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace) val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace)
val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration?
if (handshakeTimeout != null) { if (handshakeTimeout != null) {
@ -118,111 +110,6 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
Thread.currentThread().name = "$threadPoolName-${matcher.group(1)}" // Preserve the pool thread number Thread.currentThread().name = "$threadPoolName-${matcher.group(1)}" // Preserve the pool thread number
} }
} }
/**
* This is a copy of [NettyAcceptor.getSslHandler] so that we can provide different implementations for [loadOpenSslEngine] and
* [loadJdkSslEngine]. [NodeNettyAcceptor], instead of creating a default [TrustManagerFactory], will simply use the provided one in
* the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] configuration.
*/
private fun getSSLEngine(alloc: ByteBufAllocator?): SSLEngine {
val engine = if (sslProvider == TransportConstants.OPENSSL_PROVIDER) {
loadOpenSslEngine(alloc)
} else {
loadJdkSslEngine()
}
engine.useClientMode = false
if (needClientAuth) {
engine.needClientAuth = true
}
// setting the enabled cipher suites resets the enabled protocols so we need
// to save the enabled protocols so that after the customer cipher suite is enabled
// we can reset the enabled protocols if a customer protocol isn't specified
val originalProtocols = engine.enabledProtocols
if (enabledCipherSuites != null) {
try {
engine.enabledCipherSuites = SSLSupport.parseCommaSeparatedListIntoArray(enabledCipherSuites)
} catch (e: IllegalArgumentException) {
ActiveMQServerLogger.LOGGER.invalidCipherSuite(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedCipherSuites))
throw e
}
}
if (enabledProtocols != null) {
try {
engine.enabledProtocols = SSLSupport.parseCommaSeparatedListIntoArray(enabledProtocols)
} catch (e: IllegalArgumentException) {
ActiveMQServerLogger.LOGGER.invalidProtocol(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedProtocols))
throw e
}
} else {
engine.enabledProtocols = originalProtocols
}
return engine
}
/**
* Copy of [NettyAcceptor.loadOpenSslEngine] which invokes our custom [createOpenSslContext].
*/
private fun loadOpenSslEngine(alloc: ByteBufAllocator?): SSLEngine {
val context = try {
// We copied all this code just so we could replace the SSLSupport.createNettyContext method call with our own one.
createOpenSslContext()
} catch (e: Exception) {
throw IllegalStateException("Unable to create NodeNettyAcceptor", e)
}
return Subject.doAs<SSLEngine>(null, PrivilegedExceptionAction {
context.newEngine(alloc)
})
}
/**
* Copy of [NettyAcceptor.loadJdkSslEngine] which invokes our custom [createJdkSSLContext].
*/
private fun loadJdkSslEngine(): SSLEngine {
val context = try {
// We copied all this code just so we could replace the SSLHelper.createContext method call with our own one.
createJdkSSLContext()
} catch (e: Exception) {
throw IllegalStateException("Unable to create NodeNettyAcceptor", e)
}
return Subject.doAs<SSLEngine>(null, PrivilegedExceptionAction {
context.createSSLEngine()
})
}
/**
* Create an [SSLContext] using the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config.
*/
private fun createJdkSSLContext(): SSLContext {
return createAndInitSslContext(
createKeyManagerFactory(),
configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
)
}
/**
* Create an [SslContext] using the the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config.
*/
private fun createOpenSslContext(): SslContext {
return SslContextBuilder
.forServer(createKeyManagerFactory())
.sslProvider(SslProvider.OPENSSL)
.trustManager(configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?)
.build()
}
private fun createKeyManagerFactory(): KeyManagerFactory {
return keyManagerFactory(CertificateStore.fromFile(Paths.get(keyStorePath), keyStorePassword, keyStorePassword, false))
}
// Replicate the fields which are private in NettyAcceptor
private val sslProvider = ConfigurationHelper.getStringProperty(TransportConstants.SSL_PROVIDER, TransportConstants.DEFAULT_SSL_PROVIDER, configuration)
private val needClientAuth = ConfigurationHelper.getBooleanProperty(TransportConstants.NEED_CLIENT_AUTH_PROP_NAME, TransportConstants.DEFAULT_NEED_CLIENT_AUTH, configuration)
private val enabledCipherSuites = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME, TransportConstants.DEFAULT_ENABLED_CIPHER_SUITES, configuration)
private val enabledProtocols = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_PROTOCOLS_PROP_NAME, TransportConstants.DEFAULT_ENABLED_PROTOCOLS, configuration)
private val keyStorePath = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PATH_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PATH, configuration)
private val keyStoreProvider = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PROVIDER, configuration)
private val keyStorePassword = ConfigurationHelper.getPasswordProperty(TransportConstants.KEYSTORE_PASSWORD_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PASSWORD, configuration, ActiveMQDefaultConfiguration.getPropMaskPassword(), ActiveMQDefaultConfiguration.getPropPasswordCodec())
} }

View File

@ -0,0 +1,59 @@
package net.corda.node.services.messaging
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig
import java.nio.file.Paths
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
class NodeSSLContextFactory : DefaultSSLContextFactory() {
override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SSLContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory)
} else {
super.getSSLContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() {
override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SslContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
SslContextBuilder
.forServer(loadKeyManagerFactory(config))
.sslProvider(SslProvider.OPENSSL)
.trustManager(trustManagerFactory)
.build()
} else {
super.getServerSslContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory {
val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false)
return keyManagerFactory(keyStore)
}

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeOpenSSLContextFactory

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeSSLContextFactory