Merge branch 'release/os/4.9' into shams-4.10-merge-e6a80822

# Conflicts:
#	.github/workflows/check-pr-title.yml
#	.snyk
#	node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt
#	node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt
#	node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
This commit is contained in:
Shams Asari
2023-07-13 10:53:30 +01:00
67 changed files with 1687 additions and 1204 deletions

View File

@ -42,8 +42,8 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
override fun start(): Started = synchronized(this) {
check(started == null) { "start can't be called twice" }
val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace)
val backupTransports = backupServerAddressPool.map {
p2pConnectorTcpTransport(it, config, threadPoolName = threadPoolName, trace = trace)
val backupTransports = backupServerAddressPool.mapIndexed { index, address ->
p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace)
}
log.info("Connecting to message broker: $serverAddress")

View File

@ -1,16 +1,18 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path
import javax.net.ssl.TrustManagerFactory
@Suppress("LongParameterList")
class ArtemisTcpTransport {
@ -23,6 +25,7 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2")
const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
const val TRACE_NAME = "Corda-Trace"
const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
@ -30,7 +33,6 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats.
private const val P2P_PROTOCOLS = "CORE,AMQP"
private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
@ -39,46 +41,35 @@ class ArtemisTcpTransport {
TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false)
private val defaultSSLOptions = mapOf(
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(","))
private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
if (keyStore != null || trustStore != null) {
options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
}
keyStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path))
options[TransportConstants.KEYSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
}
}
trustStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path))
options[TransportConstants.TRUSTSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
}
}
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
}
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider,
@ -94,76 +85,110 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory),
enableSSL: Boolean = true,
threadPoolName: String = "P2PServer",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createAcceptorTransport(
hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
}
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
enableSSL: Boolean = true,
threadPoolName: String = "P2PClient",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
}
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: BrokerRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
threadPoolName: String = "RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads)
}
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: ClientRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
}
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration,
threadPoolName: String = "Internal-RPCClient",
trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null)
}
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration,
trace: Boolean = false): TransportConfiguration {
threadPoolName: String = "Internal-RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace)
return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
threadPoolName,
trace,
remotingThreads
)
}
private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
protocols: String,
options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort,
@ -171,7 +196,8 @@ class ArtemisTcpTransport {
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -180,15 +206,21 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
if (enableSSL) {
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
}
return createTransport(
NodeNettyConnectorFactory::class.java.name,
CordaNettyConnectorFactory::class.java.name,
hostAndPort,
protocols,
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -198,13 +230,15 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) {
options += defaultSSLOptions
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
}
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace
return TransportConfiguration(className, options)

View File

@ -1,8 +1,14 @@
@file:JvmName("ArtemisUtils")
package net.corda.nodeapi.internal
import net.corda.core.internal.declaredField
import org.apache.activemq.artemis.utils.actors.ProcessorBase
import java.nio.file.FileSystems
import java.nio.file.Path
import java.util.concurrent.Executor
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicInteger
/**
* Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use.
@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) {
require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" }
}
val Executor.rootExecutor: Executor get() {
var executor: Executor = this
while (executor is ProcessorBase<*>) {
executor = executor.declaredField<Executor>("delegate").value
}
return executor
}
fun Executor.setThreadPoolName(threadPoolName: String) {
(rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) }
}
private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory {
companion object {
private val poolId = AtomicInteger(0)
}
private val prefix = "$poolName-${poolId.incrementAndGet()}-"
private val nextId = AtomicInteger(0)
override fun newThread(r: Runnable): Thread {
val thread = delegate.newThread(r)
thread.name = "$prefix${nextId.incrementAndGet()}"
return thread
}
}

View File

@ -14,15 +14,16 @@ import org.apache.activemq.artemis.utils.ConfigurationHelper
import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService
class NodeNettyConnectorFactory : ConnectorFactory {
class CordaNettyConnectorFactory : ConnectorFactory {
override fun createConnector(configuration: MutableMap<String, Any>?,
handler: BufferHandler?,
listener: ClientConnectionLifeCycleListener?,
closeExecutor: Executor?,
threadPool: Executor?,
scheduledThreadPool: ScheduledExecutorService?,
closeExecutor: Executor,
threadPool: Executor,
scheduledThreadPool: ScheduledExecutorService,
protocolManager: ClientProtocolManager?): Connector {
val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration)
setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName)
val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
return NettyConnector(
configuration,
@ -31,7 +32,7 @@ class NodeNettyConnectorFactory : ConnectorFactory {
closeExecutor,
threadPool,
scheduledThreadPool,
MyClientProtocolManager(threadPoolName, trace)
MyClientProtocolManager("$threadPoolName-netty", trace)
)
}
@ -39,6 +40,17 @@ class NodeNettyConnectorFactory : ConnectorFactory {
override fun getDefaults(): Map<String?, Any?> = NettyConnector.DEFAULT_CONFIG
private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) {
threadPool.setThreadPoolName("$name-artemis")
// Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool
// and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names.
if (threadPool.rootExecutor !== closeExecutor.rootExecutor) {
closeExecutor.setThreadPoolName("$name-artemis-closer")
}
// The scheduler is separate
scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler")
}
private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() {
override fun addChannelHandlers(pipeline: ChannelPipeline) {

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -22,6 +22,7 @@ import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
@ -31,6 +32,7 @@ import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC
import rx.Subscription
import java.time.Duration
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.ScheduledFuture
@ -53,7 +55,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean,
sslHandshakeTimeout: Duration?,
@ -78,9 +80,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout)
private var sharedEventLoopGroup: EventLoopGroup? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
private var artemis: ArtemisSessionProvider? = null
companion object {
private val log = contextLogger()
private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads"
@ -97,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
* however Artemis and the remote Corda instanced will deduplicate these messages.
*/
@Suppress("TooManyFunctions")
private class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration,
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService?,
private val bridgeConnectionTTLSeconds: Int) {
companion object {
private val log = contextLogger()
}
private inner class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration) {
private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>()
@ -116,7 +113,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() })
MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block()
} finally {
@ -134,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
val amqpClient = AMQPClient(
targets,
allowedRemoteLegalNames,
amqpConfig,
AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!)
)
private var session: ClientSession? = null
private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null
@Volatile
private var messagesReceived: Boolean = false
private val eventLoop: EventLoop = sharedEventGroup.next()
private val eventLoop: EventLoop = sharedEventLoopGroup!!.next()
private var artemisState: ArtemisState = ArtemisState.STOPPED
set(value) {
logDebugWithMDC { "State change $field to $value" }
@ -152,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private var scheduledExecutorService: ScheduledExecutorService
= Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build())
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) {
val runnable = {
synchronized(artemis) {
synchronized(artemis!!) {
try {
val precedingState = artemisState
artemisState.pending?.cancel(false)
@ -231,7 +210,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
ArtemisState.STOPPING
}
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe()
connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -243,7 +222,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) {
logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames)
bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -253,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
}
artemis(ArtemisState.STARTING) {
val startedArtemis = artemis.started
val startedArtemis = artemis!!.started
if (startedArtemis == null) {
logInfoWithMDC("Bridge Connected but Artemis is disconnected")
ArtemisState.STOPPED
@ -286,7 +265,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
}
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -418,10 +397,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value
}
}
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(),
allowedRemoteLegalNames.first().toString(),
properties)
sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -457,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
}
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
lock.withLock {
val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() }
@ -467,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize,
revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) }
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!,
bridgeMetricsService, bridgeConnectionTTLSeconds)
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig)
bridges += newBridge
bridgeMetricsService?.bridgeCreated(targets, legalNames)
newBridge
@ -486,7 +487,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName)
}
bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames)
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
}
}
}
@ -497,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
// queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method.
val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap()
bridges?.associate {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
} ?: emptyMap()
}
}
override fun start() {
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("AMQPBridge", Thread.MAX_PRIORITY))
val artemis = artemisMessageClientFactory()
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY))
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge")
val artemis = artemisMessageClientFactory("ArtemisBridge")
this.artemis = artemis
artemis.start()
}
@ -522,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
sharedEventLoopGroup = null
queueNamesToBridgesMap.clear()
artemis?.stop()
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
}
}
}

View File

@ -35,7 +35,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean = false,
sslHandshakeTimeout: Duration? = null,
@ -80,7 +80,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId"
bridgeManager.start()
val artemis = artemisMessageClientFactory()
val artemis = artemisMessageClientFactory("BridgeControl")
this.artemis = artemis
artemis.start()
val artemisClient = artemis.started!!

View File

@ -37,7 +37,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null,
private val isLocalInbox: (String) -> Boolean,
trace: Boolean,
@ -204,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
override fun start() {
super.start()
val artemis = artemisMessageClientFactory()
val artemis = artemisMessageClientFactory("LoopbackBridge")
this.artemis = artemis
artemis.start()
}

View File

@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.*
import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.*
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -32,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair
import java.security.PublicKey
import java.security.SignatureException
import java.security.cert.*
import java.security.cert.CertPath
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.time.temporal.ChronoUnit
@ -359,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint)))
val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer))
}
@ -379,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes)
}
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
}
// Assuming cert type to role is 1:1

View File

@ -27,16 +27,17 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.Executor
import java.util.concurrent.ExecutorService
import java.util.concurrent.ThreadPoolExecutor
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
enum class ProxyVersion {
@ -63,8 +64,8 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA
class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null,
private val threadPoolName: String = "AMQPClient") : AutoCloseable {
private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"),
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -84,14 +85,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private val lock = ReentrantLock()
@Volatile
private var started: Boolean = false
private var workerGroup: EventLoopGroup? = null
@Volatile
private var clientChannel: Channel? = null
// Offset into the list of targets, so that we can implement round-robin reconnect logic.
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private val handshakeFailureRetryTargets = mutableSetOf<NetworkHostAndPort>()
private var retryingHandshakeFailures = false
private var retryOffset = 0
@ -172,7 +171,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
log.info("Failed to connect to $currentTarget", future.cause())
if (started) {
workerGroup?.schedule({
nettyThreading.eventLoopGroup.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
@ -191,7 +190,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
clientChannel = null
if (started && !amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" }
workerGroup?.schedule({
nettyThreading.eventLoopGroup.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
@ -199,17 +198,16 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
@Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
@Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
@ -249,9 +247,22 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget
val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
parent.nettyThreading.sslDelegatedTaskExecutor
)
} else {
createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory)
createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
parent.nettyThreading.sslDelegatedTaskExecutor
)
}
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler)
@ -292,7 +303,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
if (started && amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP active)" }
workerGroup?.schedule({
nettyThreading.eventLoopGroup.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
@ -309,7 +320,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
return
}
log.info("Connect to: $currentTarget")
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
(nettyThreading as? NettyThreading.NonShared)?.start()
started = true
restart()
}
@ -321,7 +332,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
val bootstrap = Bootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
// Delegate DNS Resolution to the proxy side, if we are using proxy.
if (configuration.proxyConfig != null) {
bootstrap.resolver(NoopAddressResolverGroup.INSTANCE)
@ -335,14 +346,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
lock.withLock {
log.info("Stopping connection to: $currentTarget, Local address: $localAddressString")
started = false
if (sharedThreadPool == null) {
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
if (nettyThreading is NettyThreading.NonShared) {
nettyThreading.stop()
} else {
clientChannel?.close()?.sync()
}
clientChannel = null
workerGroup = null
log.info("Stopped connection to $currentTarget")
}
}
@ -384,5 +393,35 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}
sealed class NettyThreading {
abstract val eventLoopGroup: EventLoopGroup
abstract val sslDelegatedTaskExecutor: Executor
class Shared(override val eventLoopGroup: EventLoopGroup,
override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading()
class NonShared(val threadPoolName: String) : NettyThreading() {
private var _eventLoopGroup: NioEventLoopGroup? = null
override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup)
private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null
override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor)
fun start() {
check(_eventLoopGroup == null)
check(_sslDelegatedTaskExecutor == null)
_eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
_sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
}
fun stop() {
eventLoopGroup.shutdownGracefully()
eventLoopGroup.terminationFuture().sync()
sslDelegatedTaskExecutor.shutdown()
_eventLoopGroup = null
_sslDelegatedTaskExecutor = null
}
}
}
}

View File

@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery
import rx.Observable
import rx.subjects.PublishSubject
import java.net.BindException
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
/**
@ -39,37 +38,34 @@ import kotlin.concurrent.withLock
class AMQPServer(val hostName: String,
val port: Int,
private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer") : AutoCloseable {
private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4)
private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
}
private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline()
@ -116,11 +112,12 @@ class AMQPServer(val hostName: String,
Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc())
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else {
// For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory)
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
}
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory))
@ -132,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock {
stop()
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -154,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() {
lock.withLock {
try {
stopping = true
serverChannel?.apply { close() }
serverChannel = null
serverChannel?.close()
serverChannel = null
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup = null
workerGroup = null
bossGroup = null
} finally {
stopping = false
}
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
}
}
@ -226,6 +225,4 @@ class AMQPServer(val hostName: String,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}

View File

@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return Collections.emptyList()
}
override fun clone(): AllowAllRevocationChecker = this
}

View File

@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import com.typesafe.config.Config
import net.corda.nodeapi.internal.config.ConfigParser
import net.corda.nodeapi.internal.config.CustomConfigParser
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import java.security.cert.PKIXRevocationChecker
/**
* Data structure for controlling the way how Certificate Revocation Lists are handled.
@ -45,18 +42,6 @@ interface RevocationConfig {
* Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/
val externalCrlSource: CrlSource?
fun createPKIXRevocationChecker(): PKIXRevocationChecker {
return when (mode) {
Mode.OFF -> AllowAllRevocationChecker
Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" }
CordaRevocationChecker(externalCrlSource, softFail = true)
}
Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true)
Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false)
}
}
}
/**

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.ASN1IA5String
@ -34,10 +38,10 @@ import java.net.URI
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate
import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal
import kotlin.system.measureTimeMillis
private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default"
@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton
/**
* Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/
@Suppress("ComplexMethod")
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" }
@ -117,6 +119,14 @@ fun certPathToString(certPath: Array<out X509Certificate>?): String {
return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
}
/**
* Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
* which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
*/
fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
}
@VisibleForTesting
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object {
@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
}
private object LoggingImmediateExecutor : Executor {
override fun execute(command: Runnable) {
val log = LoggerFactory.getLogger(javaClass)
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true
@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false
@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext {
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java)
.map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, newSecureRandom())
val trustManagers = trustManagerFactory
?.trustManagers
?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationConfig: RevocationConfig): CertPathTrustManagerParameters {
return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker())
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
}
/**
* Creates a special SNI handler used only when openSSL is used for AMQPServer
*/
@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build())
}
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray())
.protocols(ArtemisTcpTransport.TLS_VERSIONS)
}
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
@ -327,7 +307,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
// 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal)
fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
return keyManagerFactory
}
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache
import net.corda.core.internal.readFully
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints
import java.net.URI
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import java.time.Duration
import javax.security.auth.x500.X500Principal
/**
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate.
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them.
*/
@Suppress("TooGenericExceptionCaught")
class CertDistPointCrlSource : CrlSource {
class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE,
cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY,
private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT,
private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource {
companion object {
private val logger = contextLogger()
// The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a
// node handshake, we want to keep the total timeout within that.
private const val DEFAULT_CONNECT_TIMEOUT = 9_000
private const val DEFAULT_READ_TIMEOUT = 9_000
private val DEFAULT_CONNECT_TIMEOUT = 9.seconds
private val DEFAULT_READ_TIMEOUT = 9.seconds
private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore)
private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L
private val DEFAULT_CACHE_EXPIRY = 5.minutes
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE))
.expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS)
.build(::retrieveCRL)
val SINGLETON = CertDistPointCrlSource(
cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE),
cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY,
connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT,
readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT
)
}
private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT)
private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT)
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(cacheSize)
.expireAfterWrite(cacheExpiry)
.build(::retrieveCRL)
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout
conn.readTimeout = readTimeout
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
throw e
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout.toMillis().toInt()
conn.readTimeout = readTimeout.toMillis().toInt()
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
throw e
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
}
fun clearCache() {
cache.invalidateAll()
}
override fun fetch(certificate: X509Certificate): Set<X509CRL> {

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Rule
import org.junit.Test
@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Ignore
import org.junit.Rule
@ -18,7 +19,6 @@ import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import javax.net.ssl.*
import javax.net.ssl.SNIHostName
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertFalse
@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.util.concurrent.ImmediateExecutor
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.testing.internal.fixedCrlSource
import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals
class SSLHelperTest {
@ -20,15 +20,21 @@ class SSLHelperTest {
val legalName = CordaX500Name("Test", "London", "GB")
val sslConfig = configureTestSSL(legalName)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyStore = sslConfig.keyStore
keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false))
val trustStore = sslConfig.trustStore
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)))
val trustManagerFactory = trustManagerFactoryWithRevocation(
sslConfig.trustStore.get(),
RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL),
fixedCrlSource(emptySet())
)
val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory)
val sslHandler = createClientSslHandler(
NetworkHostAndPort("localhost", 1234),
setOf(legalName),
keyManagerFactory,
trustManagerFactory,
ImmediateExecutor.INSTANCE
)
val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
// These hardcoded values must not be changed, something is broken if you have to change these hardcoded values.

View File

@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.testing.core.ALICE_NAME
import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.node.internal.network.CrlServer
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.math.BigInteger
class CertDistPointCrlSourceTest {
private lateinit var crlServer: CrlServer
@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest {
assertThat(single().revokedCertificates).isNull()
}
val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate)
crlSource.clearCache()
crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN)
with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache
crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate
with(crlSource.fetch(crlServer.intermediateCa.certificate)) {
assertThat(size).isEqualTo(1)
val revokedCertificates = single().revokedCertificates
assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN)
// This also tests clearCache() works.
assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber)
}
}
}

View File

@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.testing.internal.fixedCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test
import java.math.BigInteger
@ -41,10 +41,8 @@ class CordaRevocationCheckerTest {
val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL
val crlSource = object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl)
}
val checker = CordaRevocationChecker(crlSource,
val checker = CordaRevocationChecker(
crlSource = fixedCrlSource(setOf(crl)),
softFail = true,
dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) }
)

View File

@ -0,0 +1,203 @@
package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509KeyStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x500.X500Name
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.io.File
import java.security.KeyPair
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.X509TrustManager
import javax.security.auth.x500.X500Principal
import kotlin.test.assertFailsWith
@RunWith(Parameterized::class)
class RevocationTest(private val revocationMode: RevocationConfig.Mode) {
companion object {
@JvmStatic
@Parameterized.Parameters(name = "revocationMode = {0}")
fun data() = listOf(RevocationConfig.Mode.OFF, RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Rule
@JvmField
val tempFolder = TemporaryFolder()
private lateinit var rootCRL: File
private lateinit var doormanCRL: File
private lateinit var tlsCRL: File
private lateinit var trustManager: X509TrustManager
private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val doormanKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val nodeCAKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private lateinit var rootCert: X509Certificate
private lateinit var tlsCRLIssuerCert: X509Certificate
private lateinit var doormanCert: X509Certificate
private lateinit var nodeCACert: X509Certificate
private lateinit var tlsCert: X509Certificate
private val chain
get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert)
@Before
fun before() {
rootCRL = tempFolder.newFile("root.crl")
doormanCRL = tempFolder.newFile("doorman.crl")
tlsCRL = tempFolder.newFile("tls.crl")
rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair)
tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair)
val trustStore = KeyStore.getInstance("JKS")
trustStore.load(null, null)
trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert)
trustStore.setCertificateEntry("cordarootca", rootCert)
val trustManagerFactory = trustManagerFactoryWithRevocation(
CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"),
RevocationConfigImpl(revocationMode),
CertDistPointCrlSource()
)
trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager
doormanCert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public,
crlDistPoint = rootCRL.toURI().toString()
)
nodeCACert = X509Utilities.createCertificate(
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=node"), nodeCAKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString()
)
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
rootCRL.writeCRL(rootCert, rootKeyPair.private, false)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
}
private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val crl = createCRL(
CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)),
revoked.asList(),
indirect = indirect
)
writeBytes(crl.encoded)
}
private fun assertFailsFor(vararg modes: RevocationConfig.Mode) {
if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck()
}
@Test(timeout = 300_000)
fun `ok with empty CRLs`() {
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked TLS certificate`() {
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail with unavailable CRL in TLS certificate`() {
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/tls",
crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail with invalid CRL issuer in TLS certificate`() {
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown")
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail without CRL issuer in TLS certificate`() {
tlsCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=tls"), tlsKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString()
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `ok with other certificate in TLS CRL`() {
val otherKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
val otherCert = X509Utilities.createCertificate(
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked node CA certificate`() {
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `hard fail with unavailable CRL in node CA certificate`() {
nodeCACert = X509Utilities.createCertificate(
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=node"), nodeCAKeyPair.public,
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman"
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
fun `ok with other certificate in doorman CRL`() {
val otherKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
val otherCert = X509Utilities.createCertificate(
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString()
)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert)
doRevocationCheck()
}
private fun doRevocationCheck() {
trustManager.checkClientTrusted(chain, "ECDHE_ECDSA")
}
}