ENT-9806: Prevent Netty threads being blocked due to unresponsive CRL endpoints

This commit is contained in:
Shams Asari
2023-05-02 14:38:56 +01:00
parent 31a34e5a5c
commit 0a617097be
31 changed files with 1110 additions and 777 deletions

View File

@ -1,16 +1,18 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path
import javax.net.ssl.TrustManagerFactory
@Suppress("LongParameterList")
class ArtemisTcpTransport {
@ -23,6 +25,7 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2")
const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
const val TRACE_NAME = "Corda-Trace"
const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
@ -30,7 +33,6 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats.
private const val P2P_PROTOCOLS = "CORE,AMQP"
private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
@ -39,46 +41,35 @@ class ArtemisTcpTransport {
TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false)
private val defaultSSLOptions = mapOf(
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(","))
private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
if (keyStore != null || trustStore != null) {
options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
}
keyStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path))
options[TransportConstants.KEYSTORE_PROVIDER_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
}
}
trustStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path))
options[TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
}
}
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
}
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to trustStoreProvider,
@ -94,50 +85,64 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean = true,
threadPoolName: String = "P2PServer",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createAcceptorTransport(
hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
}
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
enableSSL: Boolean = true,
threadPoolName: String = "P2PClient",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
}
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: BrokerRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, "RPCServer", trace, remotingThreads)
}
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: ClientRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
}
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
@ -145,25 +150,45 @@ class ArtemisTcpTransport {
trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace, null)
}
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace)
return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
"Internal-RPCServer",
trace,
remotingThreads
)
}
private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
protocols: String,
options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, which we pass directly into NodeNettyAcceptorFactory.
//
// This, however, requires copying a lot of code from NettyAcceptor into NodeNettyAcceptor. The version of Artemis in
// Corda 4.9 solves this problem by introducing a "trustManagerFactoryPlugin" config option.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort,
@ -171,7 +196,8 @@ class ArtemisTcpTransport {
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -180,7 +206,8 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
return createTransport(
"net.corda.node.services.messaging.NodeNettyConnectorFactory",
hostAndPort,
@ -188,7 +215,8 @@ class ArtemisTcpTransport {
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -198,11 +226,15 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) {
options += defaultSSLOptions
options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
}
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace
return TransportConfiguration(className, options)

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -100,7 +100,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration,
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
@ -116,7 +116,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() })
MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block()
} finally {
@ -134,7 +134,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
val amqpClient = AMQPClient(targets, allowedRemoteLegalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
private var session: ClientSession? = null
private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null
@ -231,7 +231,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
ArtemisState.STOPPING
}
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe()
connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -243,7 +243,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) {
logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames)
bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -286,7 +286,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
}
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -418,10 +418,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value
}
}
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(),
allowedRemoteLegalNames.first().toString(),
properties)
sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -486,7 +486,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName)
}
bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames)
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
}
}
}
@ -498,7 +498,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false)
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap()
}
}

View File

@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.*
import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.*
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -32,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair
import java.security.PublicKey
import java.security.SignatureException
import java.security.cert.*
import java.security.cert.CertPath
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.time.temporal.ChronoUnit
@ -359,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint)))
val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer))
}
@ -379,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes)
}
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
}
// Assuming cert type to role is 1:1

View File

@ -27,15 +27,14 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
enum class ProxyVersion {
@ -63,7 +62,8 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null,
private val threadPoolName: String = "AMQPClient") : AutoCloseable {
private val threadPoolName: String = "AMQPClient",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -89,12 +89,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private val badCertTargets = mutableSetOf<NetworkHostAndPort>()
@Volatile
private var amqpActive = false
@Volatile
private var amqpChannelHandler: ChannelHandler? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
val localAddressString: String
get() = clientChannel?.localAddress()?.toString() ?: "<unknownLocalAddress>"
@ -150,17 +150,16 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
@Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
@Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
@ -199,10 +198,24 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
delegatedTaskExecutor
)
} else {
createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory)
createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
delegatedTaskExecutor
)
}
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler)
@ -260,6 +273,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
return
}
log.info("Connect to: $currentTarget")
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
started = true
restart()
@ -294,6 +308,8 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
clientChannel = null
workerGroup = null
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
log.info("Stopped connection to $currentTarget")
}
}
@ -334,6 +350,4 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}
}

View File

@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery
import rx.Observable
import rx.subjects.PublishSubject
import java.net.BindException
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
/**
@ -39,37 +38,34 @@ import kotlin.concurrent.withLock
class AMQPServer(val hostName: String,
val port: Int,
private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer") : AutoCloseable {
private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4)
private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
}
private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline()
@ -116,11 +112,12 @@ class AMQPServer(val hostName: String,
Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc())
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else {
// For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory)
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
}
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory))
@ -132,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock {
stop()
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -154,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() {
lock.withLock {
try {
stopping = true
serverChannel?.apply { close() }
serverChannel = null
serverChannel?.close()
serverChannel = null
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup = null
workerGroup = null
bossGroup = null
} finally {
stopping = false
}
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
}
}
@ -226,6 +225,4 @@ class AMQPServer(val hostName: String,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}

View File

@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return Collections.emptyList()
}
override fun clone(): AllowAllRevocationChecker = this
}

View File

@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import com.typesafe.config.Config
import net.corda.nodeapi.internal.config.ConfigParser
import net.corda.nodeapi.internal.config.CustomConfigParser
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import java.security.cert.PKIXRevocationChecker
/**
* Data structure for controlling the way how Certificate Revocation Lists are handled.
@ -45,18 +42,6 @@ interface RevocationConfig {
* Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/
val externalCrlSource: CrlSource?
fun createPKIXRevocationChecker(): PKIXRevocationChecker {
return when (mode) {
Mode.OFF -> AllowAllRevocationChecker
Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" }
CordaRevocationChecker(externalCrlSource, softFail = true)
}
Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true)
Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false)
}
}
}
/**

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.DERIA5String
@ -34,10 +38,10 @@ import java.net.URI
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate
import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal
import kotlin.system.measureTimeMillis
private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default"
@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton
/**
* Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/
@Suppress("ComplexMethod")
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" }
@ -117,6 +119,14 @@ fun certPathToString(certPath: Array<out X509Certificate>?): String {
return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
}
/**
* Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
* which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
*/
fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
}
@VisibleForTesting
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object {
@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
}
private object LoggingImmediateExecutor : Executor {
override fun execute(command: Runnable) {
val log = LoggerFactory.getLogger(javaClass)
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true
@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false
@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext {
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java)
.map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, newSecureRandom())
val trustManagers = trustManagerFactory
?.trustManagers
?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationConfig: RevocationConfig): CertPathTrustManagerParameters {
return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker())
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
}
/**
* Creates a special SNI handler used only when openSSL is used for AMQPServer
*/
@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build())
}
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray())
.protocols(ArtemisTcpTransport.TLS_VERSIONS)
}
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
@ -325,9 +305,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
// As per Javadoc in: https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/KeyManagerFactory.html `init` method
// 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
return keyManagerFactory
}
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal)
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache
import net.corda.core.internal.readFully
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints
import java.net.URI
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import java.time.Duration
import javax.security.auth.x500.X500Principal
/**
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate.
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them.
*/
@Suppress("TooGenericExceptionCaught")
class CertDistPointCrlSource : CrlSource {
class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE,
cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY,
private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT,
private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource {
companion object {
private val logger = contextLogger()
// The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a
// node handshake, we want to keep the total timeout within that.
private const val DEFAULT_CONNECT_TIMEOUT = 9_000
private const val DEFAULT_READ_TIMEOUT = 9_000
private val DEFAULT_CONNECT_TIMEOUT = 9.seconds
private val DEFAULT_READ_TIMEOUT = 9.seconds
private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore)
private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L
private val DEFAULT_CACHE_EXPIRY = 5.minutes
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE))
.expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS)
.build(::retrieveCRL)
val SINGLETON = CertDistPointCrlSource(
cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE),
cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY,
connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT,
readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT
)
}
private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT)
private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT)
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(cacheSize)
.expireAfterWrite(cacheExpiry)
.build(::retrieveCRL)
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout
conn.readTimeout = readTimeout
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
throw e
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout.toMillis().toInt()
conn.readTimeout = readTimeout.toMillis().toInt()
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
throw e
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
}
fun clearCache() {
cache.invalidateAll()
}
override fun fetch(certificate: X509Certificate): Set<X509CRL> {

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Rule
import org.junit.Test
@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Ignore
import org.junit.Rule
@ -18,7 +19,6 @@ import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import javax.net.ssl.*
import javax.net.ssl.SNIHostName
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertFalse
@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.util.concurrent.ImmediateExecutor
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.testing.internal.fixedCrlSource
import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals
class SSLHelperTest {
@ -20,15 +20,21 @@ class SSLHelperTest {
val legalName = CordaX500Name("Test", "London", "GB")
val sslConfig = configureTestSSL(legalName)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyStore = sslConfig.keyStore
keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false))
val trustStore = sslConfig.trustStore
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)))
val trustManagerFactory = trustManagerFactoryWithRevocation(
sslConfig.trustStore.get(),
RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL),
fixedCrlSource(emptySet())
)
val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory)
val sslHandler = createClientSslHandler(
NetworkHostAndPort("localhost", 1234),
setOf(legalName),
keyManagerFactory,
trustManagerFactory,
ImmediateExecutor.INSTANCE
)
val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
// These hardcoded values must not be changed, something is broken if you have to change these hardcoded values.

View File

@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.testing.core.ALICE_NAME
import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.node.internal.network.CrlServer
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.math.BigInteger
class CertDistPointCrlSourceTest {
private lateinit var crlServer: CrlServer
@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest {
assertThat(single().revokedCertificates).isNull()
}
val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate)
crlSource.clearCache()
crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN)
with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache
crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate
with(crlSource.fetch(crlServer.intermediateCa.certificate)) {
assertThat(size).isEqualTo(1)
val revokedCertificates = single().revokedCertificates
assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN)
// This also tests clearCache() works.
assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber)
}
}
}

View File

@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.testing.internal.fixedCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test
import java.math.BigInteger
@ -41,10 +41,8 @@ class CordaRevocationCheckerTest {
val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL
val crlSource = object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl)
}
val checker = CordaRevocationChecker(crlSource,
val checker = CordaRevocationChecker(
crlSource = fixedCrlSource(setOf(crl)),
softFail = true,
dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) }
)

View File

@ -0,0 +1,203 @@
package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509KeyStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x500.X500Name
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.io.File
import java.security.KeyPair
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.X509TrustManager
import javax.security.auth.x500.X500Principal
import kotlin.test.assertFailsWith
@RunWith(Parameterized::class)
class RevocationTest(private val revocationMode: RevocationConfig.Mode) {
companion object {
@JvmStatic
@Parameterized.Parameters(name = "revocationMode = {0}")
fun data() = listOf(RevocationConfig.Mode.OFF, RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Rule
@JvmField
val tempFolder = TemporaryFolder()
private lateinit var rootCRL: File
private lateinit var doormanCRL: File
private lateinit var tlsCRL: File
private lateinit var trustManager: X509TrustManager
private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val doormanKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val nodeCAKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private lateinit var rootCert: X509Certificate
private lateinit var tlsCRLIssuerCert: X509Certificate
private lateinit var doormanCert: X509Certificate
private lateinit var nodeCACert: X509Certificate
private lateinit var tlsCert: X509Certificate
private val chain
get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert)
@Before
fun before() {
rootCRL = tempFolder.newFile("root.crl")
doormanCRL = tempFolder.newFile("doorman.crl")
tlsCRL = tempFolder.newFile("tls.crl")
rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair)
tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair)
val trustStore = KeyStore.getInstance("JKS")
trustStore.load(null, null)
trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert)
trustStore.setCertificateEntry("cordarootca", rootCert)
val trustManagerFactory = trustManagerFactoryWithRevocation(
CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"),
RevocationConfigImpl(revocationMode),
CertDistPointCrlSource()
)
trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager
doormanCert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public,
crlDistPoint = rootCRL.toURI().toString()
)
nodeCACert = X509Utilities.createCertificate(
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=node"), nodeCAKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString()
)
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
rootCRL.writeCRL(rootCert, rootKeyPair.private, false)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
}
private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val crl = createCRL(
CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)),
revoked.asList(),
indirect = indirect
)
writeBytes(crl.encoded)
}
private fun assertFailsFor(vararg modes: RevocationConfig.Mode) {
if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck()
}
@Test(timeout = 300_000)
fun `ok with empty CRLs`() {
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked TLS certificate`() {
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail with unavailable CRL in TLS certificate`() {
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/tls",
crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail with invalid CRL issuer in TLS certificate`() {
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown")
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail without CRL issuer in TLS certificate`() {
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString()
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `ok with other certificate in TLS CRL`() {
val otherKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
val otherCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked node CA certificate`() {
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail with unavailable CRL in node CA certificate`() {
nodeCACert = X509Utilities.createCertificate(
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=node"), nodeCAKeyPair.public,
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman"
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `ok with other certificate in doorman CRL`() {
val otherKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
val otherCert = X509Utilities.createCertificate(
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString()
)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert)
doRevocationCheck()
}
private fun doRevocationCheck() {
trustManager.checkClientTrusted(chain, "ECDHE_ECDSA")
}
}