diff --git a/.snyk b/.snyk index 3970c56889..2fbffb7245 100755 --- a/.snyk +++ b/.snyk @@ -159,7 +159,7 @@ ignore: assessment. Liquibase is used to apply the database migration changes. XML files are used here to define the changes not YAML and therefore the Corda node itself is not exposed to this deserialisation - vulnerability. + vulnerability. expires: 2023-07-12T17:00:51.957Z created: 2022-12-29T17:00:51.970Z SNYK-JAVA-ORGYAML-3016889: @@ -180,7 +180,7 @@ ignore: - '*': reason: >- H2 console is not enabled for any of the applications we are running. - When it comes to DB connectivity parameters, we do not allow changing + When it comes to DB connectivity parameters, we do not allow changing them as they are supplied by Corda Node configuration file. expires: 2023-07-28T11:36:39.068Z created: 2022-12-29T11:36:39.089Z diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index 12cc2c7d1a..0322df3385 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -1,5 +1,6 @@ package net.corda.client.rpc +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.ReconnectingCordaRPCOps import net.corda.client.rpc.internal.SerializationEnvironmentHelper @@ -52,7 +53,7 @@ class CordaRPCConnection private constructor( sslConfiguration: ClientRpcSslOptions? = null, classLoader: ClassLoader? = null ): CordaRPCConnection { - val observersPool: ExecutorService = Executors.newCachedThreadPool() + val observersPool: ExecutorService = Executors.newCachedThreadPool(DefaultThreadFactory("RPCObserver")) return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps( addresses, username, diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt index 6f7df5cf84..fab8613139 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt @@ -1,5 +1,6 @@ package net.corda.client.rpc.internal +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.client.rpc.ConnectionFailureException import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClientConfiguration @@ -99,7 +100,8 @@ class ReconnectingCordaRPCOps private constructor( ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps } } - private val retryFlowsPool = Executors.newScheduledThreadPool(1) + private val retryFlowsPool = Executors.newScheduledThreadPool(1, DefaultThreadFactory("FlowRetry")) + /** * This function runs a flow and retries until it completes successfully. * diff --git a/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt new file mode 100644 index 0000000000..11a7e9ed85 --- /dev/null +++ b/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt @@ -0,0 +1,29 @@ +package net.corda.coretests.crypto.internal + +import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.testing.core.createCRL +import org.assertj.core.api.Assertions.assertThatIllegalArgumentException +import org.junit.Test + +class ProviderMapTest { + // https://github.com/corda/corda/pull/3997 + @Test(timeout = 300_000) + fun `verify CRL algorithms`() { + val crl = createCRL( + issuer = DEV_ROOT_CA, + revokedCerts = emptyList(), + signatureAlgorithm = "SHA256withECDSA" + ) + // This should pass. + crl.verify(DEV_ROOT_CA.keyPair.public) + + // Try changing the algorithm to EC will fail. + assertThatIllegalArgumentException().isThrownBy { + createCRL( + issuer = DEV_ROOT_CA, + revokedCerts = emptyList(), + signatureAlgorithm = "EC" + ) + }.withMessage("Unknown signature type requested: EC") + } +} diff --git a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt index 4b3456d084..82c5ddea23 100644 --- a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt @@ -67,7 +67,7 @@ interface ServicesForResolution { /** * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * - * @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. + * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction. */ // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // as the existing transaction store will become encrypted at some point diff --git a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt index 9ce80590a1..9affc6a0b1 100644 --- a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt +++ b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt @@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.days import net.corda.core.utilities.hours -import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme +import net.corda.coretesting.internal.NettyTestClient +import net.corda.coretesting.internal.NettyTestHandler +import net.corda.coretesting.internal.NettyTestServer +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.createDevNodeCa +import net.corda.nodeapi.internal.crypto.CertificateType +import net.corda.nodeapi.internal.crypto.X509CertificateFactory +import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME +import net.corda.nodeapi.internal.crypto.checkValidity +import net.corda.nodeapi.internal.crypto.getSupportedKey +import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore +import net.corda.nodeapi.internal.crypto.save +import net.corda.nodeapi.internal.crypto.toBc +import net.corda.nodeapi.internal.crypto.x509 +import net.corda.nodeapi.internal.crypto.x509Certificates import net.corda.nodeapi.internal.installDevNodeCaCertPath -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import net.corda.nodeapi.internal.registerDevP2pCertificates +import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationFactoryImpl @@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.TestIdentity import net.corda.testing.driver.internal.incrementalPortAllocation -import net.corda.coretesting.internal.NettyTestClient -import net.corda.coretesting.internal.NettyTestHandler -import net.corda.coretesting.internal.NettyTestServer -import net.corda.testing.internal.createDevIntermediateCaCertPath -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.nodeapi.internal.crypto.CertificateType -import net.corda.nodeapi.internal.crypto.X509CertificateFactory -import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.crypto.checkValidity -import net.corda.nodeapi.internal.crypto.getSupportedKey -import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore -import net.corda.nodeapi.internal.crypto.save -import net.corda.nodeapi.internal.crypto.toBc -import net.corda.nodeapi.internal.crypto.x509 -import net.corda.nodeapi.internal.crypto.x509Certificates import net.corda.testing.internal.IS_OPENJ9 +import net.corda.testing.internal.createDevIntermediateCaCertPath import net.i2p.crypto.eddsa.EdDSAPrivateKey import org.assertj.core.api.Assertions.assertThat -import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x509.BasicConstraints +import org.bouncycastle.asn1.x509.CRLDistPoint +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.KeyUsage +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.junit.Assume @@ -74,10 +80,19 @@ import java.security.PrivateKey import java.security.cert.CertPath import java.security.cert.X509Certificate import java.util.* -import javax.net.ssl.* +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLParameters +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket import javax.security.auth.x500.X500Principal import kotlin.concurrent.thread -import kotlin.test.* +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail class X509UtilitiesTest { private companion object { @@ -295,15 +310,10 @@ class X509UtilitiesTest { sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - val context = SSLContext.getInstance("TLS") - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get()) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -388,15 +398,8 @@ class X509UtilitiesTest { sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) - - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustManagerFactory.init(trustStore) - + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) + val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get()) val sslServerContext = SslContextBuilder .forServer(keyManagerFactory) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt index 7be0ac6229..1c914c35c4 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt @@ -42,8 +42,8 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration, override fun start(): Started = synchronized(this) { check(started == null) { "start can't be called twice" } val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace) - val backupTransports = backupServerAddressPool.map { - p2pConnectorTcpTransport(it, config, threadPoolName = threadPoolName, trace = trace) + val backupTransports = backupServerAddressPool.mapIndexed { index, address -> + p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace) } log.info("Connecting to message broker: $serverAddress") diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index f46202cc07..6b5b353b3e 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -1,16 +1,18 @@ +@file:Suppress("LongParameterList") + package net.corda.nodeapi.internal import net.corda.core.messaging.ClientRpcSslOptions import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.BrokerRpcSslOptions -import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.apache.activemq.artemis.api.core.TransportConfiguration import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants -import java.nio.file.Path +import javax.net.ssl.TrustManagerFactory @Suppress("LongParameterList") class ArtemisTcpTransport { @@ -23,6 +25,7 @@ class ArtemisTcpTransport { val TLS_VERSIONS = listOf("TLSv1.2") const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout" + const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory" const val TRACE_NAME = "Corda-Trace" const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName" @@ -30,7 +33,6 @@ class ArtemisTcpTransport { // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // It does not use AMQP messages for its own messages e.g. topology and heartbeats. private const val P2P_PROTOCOLS = "CORE,AMQP" - private const val RPC_PROTOCOLS = "CORE" private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf( @@ -39,46 +41,35 @@ class ArtemisTcpTransport { TransportConstants.PORT_PROP_NAME to hostAndPort.port, TransportConstants.PROTOCOLS_PROP_NAME to protocols, TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null), - TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1), // turn off direct delivery in Artemis - this is latency optimisation that can lead to //hick-ups under high load (CORDA-1336) TransportConstants.DIRECT_DELIVER to false) - private val defaultSSLOptions = mapOf( - TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","), - TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(",")) - private fun SslConfiguration.addToTransportOptions(options: MutableMap) { + if (keyStore != null || trustStore != null) { + options[TransportConstants.SSL_ENABLED_PROP_NAME] = true + options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true + } keyStore?.let { with (it) { path.requireOnDefaultFileSystem() - options.putAll(get().toKeyStoreTransportOptions(path)) + options[TransportConstants.KEYSTORE_TYPE_PROP_NAME] = "JKS" + options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path + options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password } } trustStore?.let { with (it) { path.requireOnDefaultFileSystem() - options.putAll(get().toTrustStoreTransportOptions(path)) + options[TransportConstants.TRUSTSTORE_TYPE_PROP_NAME] = "JKS" + options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path + options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password } } options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT } - private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf( - TransportConstants.SSL_ENABLED_PROP_NAME to true, - TransportConstants.KEYSTORE_TYPE_PROP_NAME to "JKS", - TransportConstants.KEYSTORE_PATH_PROP_NAME to path, - TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password, - TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true) - - private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf( - TransportConstants.SSL_ENABLED_PROP_NAME to true, - TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to "JKS", - TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path, - TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password, - TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true) - private fun ClientRpcSslOptions.toTransportOptions() = mapOf( TransportConstants.SSL_ENABLED_PROP_NAME to true, TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider, @@ -94,76 +85,110 @@ class ArtemisTcpTransport { fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, + trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory), enableSSL: Boolean = true, threadPoolName: String = "P2PServer", - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (enableSSL) { config?.addToTransportOptions(options) } - return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) + return createAcceptorTransport( + hostAndPort, + P2P_PROTOCOLS, + options, + trustManagerFactory, + enableSSL, + threadPoolName, + trace, + remotingThreads + ) } fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true, threadPoolName: String = "P2PClient", - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (enableSSL) { config?.addToTransportOptions(options) } - return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) + return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads) } fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: BrokerRpcSslOptions?, enableSSL: Boolean = true, - trace: Boolean = false): TransportConfiguration { + threadPoolName: String = "RPCServer", + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (config != null && enableSSL) { config.keyStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) } - return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace) + return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads) } fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: ClientRpcSslOptions?, enableSSL: Boolean = true, - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (config != null && enableSSL) { config.trustStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) } - return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads) } fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, + threadPoolName: String = "Internal-RPCClient", trace: Boolean = false): TransportConfiguration { val options = mutableMapOf() config.addToTransportOptions(options) - return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null) } fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, - trace: Boolean = false): TransportConfiguration { + threadPoolName: String = "Internal-RPCServer", + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() config.addToTransportOptions(options) - return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace) + return createAcceptorTransport( + hostAndPort, + RPC_PROTOCOLS, + options, + trustManagerFactory(requireNotNull(config.trustStore).get()), + true, + threadPoolName, + trace, + remotingThreads + ) } private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort, protocols: String, options: MutableMap, + trustManagerFactory: TrustManagerFactory?, enableSSL: Boolean, threadPoolName: String, - trace: Boolean): TransportConfiguration { + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 + if (trustManagerFactory != null) { + // NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use + // more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory. + options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory + } return createTransport( "net.corda.node.services.messaging.NodeNettyAcceptorFactory", hostAndPort, @@ -171,7 +196,8 @@ class ArtemisTcpTransport { options, enableSSL, threadPoolName, - trace + trace, + remotingThreads ) } @@ -180,15 +206,21 @@ class ArtemisTcpTransport { options: MutableMap, enableSSL: Boolean, threadPoolName: String, - trace: Boolean): TransportConfiguration { + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { + if (enableSSL) { + // This is required to stop Client checking URL address vs. Server provided certificate + options[TransportConstants.VERIFY_HOST_PROP_NAME] = false + } return createTransport( - NodeNettyConnectorFactory::class.java.name, + CordaNettyConnectorFactory::class.java.name, hostAndPort, protocols, options, enableSSL, threadPoolName, - trace + trace, + remotingThreads ) } @@ -198,13 +230,15 @@ class ArtemisTcpTransport { options: MutableMap, enableSSL: Boolean, threadPoolName: String, - trace: Boolean): TransportConfiguration { + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { options += defaultArtemisOptions(hostAndPort, protocols) if (enableSSL) { - options += defaultSSLOptions - // This is required to stop Client checking URL address vs. Server provided certificate - options[TransportConstants.VERIFY_HOST_PROP_NAME] = false + options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",") + options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",") } + // By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357) + options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1 options[THREAD_POOL_NAME_NAME] = threadPoolName options[TRACE_NAME] = trace return TransportConfiguration(className, options) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt index 23bb9d1428..a3c2109d32 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt @@ -1,8 +1,14 @@ @file:JvmName("ArtemisUtils") package net.corda.nodeapi.internal +import net.corda.core.internal.declaredField +import org.apache.activemq.artemis.utils.actors.ProcessorBase import java.nio.file.FileSystems import java.nio.file.Path +import java.util.concurrent.Executor +import java.util.concurrent.ThreadFactory +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicInteger /** * Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use. @@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) { require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" } } +val Executor.rootExecutor: Executor get() { + var executor: Executor = this + while (executor is ProcessorBase<*>) { + executor = executor.declaredField("delegate").value + } + return executor +} + +fun Executor.setThreadPoolName(threadPoolName: String) { + (rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) } +} + +private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory { + companion object { + private val poolId = AtomicInteger(0) + } + + private val prefix = "$poolName-${poolId.incrementAndGet()}-" + private val nextId = AtomicInteger(0) + + override fun newThread(r: Runnable): Thread { + val thread = delegate.newThread(r) + thread.name = "$prefix${nextId.incrementAndGet()}" + return thread + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt similarity index 70% rename from node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt index 47e046566e..a9bdc519a9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt @@ -14,15 +14,16 @@ import org.apache.activemq.artemis.utils.ConfigurationHelper import java.util.concurrent.Executor import java.util.concurrent.ScheduledExecutorService -class NodeNettyConnectorFactory : ConnectorFactory { +class CordaNettyConnectorFactory : ConnectorFactory { override fun createConnector(configuration: MutableMap?, handler: BufferHandler?, listener: ClientConnectionLifeCycleListener?, - closeExecutor: Executor?, - threadPool: Executor?, - scheduledThreadPool: ScheduledExecutorService?, + closeExecutor: Executor, + threadPool: Executor, + scheduledThreadPool: ScheduledExecutorService, protocolManager: ClientProtocolManager?): Connector { val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration) + setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName) val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) return NettyConnector( configuration, @@ -31,7 +32,7 @@ class NodeNettyConnectorFactory : ConnectorFactory { closeExecutor, threadPool, scheduledThreadPool, - MyClientProtocolManager(threadPoolName, trace) + MyClientProtocolManager("$threadPoolName-netty", trace) ) } @@ -39,6 +40,17 @@ class NodeNettyConnectorFactory : ConnectorFactory { override fun getDefaults(): Map = NettyConnector.DEFAULT_CONFIG + private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) { + threadPool.setThreadPoolName("$name-artemis") + // Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool + // and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names. + if (threadPool.rootExecutor !== closeExecutor.rootExecutor) { + closeExecutor.setThreadPoolName("$name-artemis-closer") + } + // The scheduler is separate + scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler") + } + private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() { override fun addChannelHandlers(pipeline: ChannelPipeline) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt new file mode 100644 index 0000000000..65d60ab38d --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt @@ -0,0 +1,32 @@ +@file:Suppress("LongParameterList", "MagicNumber") + +package net.corda.nodeapi.internal + +import io.netty.util.concurrent.DefaultThreadFactory +import net.corda.core.utilities.seconds +import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit + +/** + * Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0 + * threads. + */ +fun namedThreadPoolExecutor(maxPoolSize: Int, + corePoolSize: Int = 0, + idleKeepAlive: Duration = 30.seconds, + workQueue: BlockingQueue = LinkedBlockingQueue(), + poolName: String = "pool", + daemonThreads: Boolean = false, + threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor { + return ThreadPoolExecutor( + corePoolSize, + maxPoolSize, + idleKeepAlive.toNanos(), + TimeUnit.NANOSECONDS, + workQueue, + DefaultThreadFactory(poolName, daemonThreads, threadPriority) + ) +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt index ee09e640ad..93ab5616de 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt @@ -22,6 +22,7 @@ import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE @@ -31,6 +32,7 @@ import org.apache.activemq.artemis.api.core.client.ClientSession import org.slf4j.MDC import rx.Subscription import java.time.Duration +import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledFuture @@ -53,7 +55,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, private val bridgeMetricsService: BridgeMetricsService? = null, trace: Boolean, sslHandshakeTimeout: Duration?, @@ -78,9 +80,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout) private var sharedEventLoopGroup: EventLoopGroup? = null + private var sslDelegatedTaskExecutor: ExecutorService? = null private var artemis: ArtemisSessionProvider? = null companion object { + private val log = contextLogger() private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads" @@ -97,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore, * however Artemis and the remote Corda instanced will deduplicate these messages. */ @Suppress("TooManyFunctions") - private class AMQPBridge(val sourceX500Name: String, - val queueName: String, - val targets: List, - val legalNames: Set, - private val amqpConfig: AMQPConfiguration, - sharedEventGroup: EventLoopGroup, - private val artemis: ArtemisSessionProvider, - private val bridgeMetricsService: BridgeMetricsService?, - private val bridgeConnectionTTLSeconds: Int) { - companion object { - private val log = contextLogger() - } + private inner class AMQPBridge(val sourceX500Name: String, + val queueName: String, + val targets: List, + val allowedRemoteLegalNames: Set, + private val amqpConfig: AMQPConfiguration) { private fun withMDC(block: () -> Unit) { val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap() @@ -116,7 +113,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, MDC.put("queueName", queueName) MDC.put("source", amqpConfig.sourceX500Name) MDC.put("targets", targets.joinToString(separator = ";") { it.toString() }) - MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() }) + MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() }) MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString()) block() } finally { @@ -134,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } - val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup) + val amqpClient = AMQPClient( + targets, + allowedRemoteLegalNames, + amqpConfig, + AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!) + ) private var session: ClientSession? = null private var consumer: ClientConsumer? = null private var connectedSubscription: Subscription? = null @Volatile private var messagesReceived: Boolean = false - private val eventLoop: EventLoop = sharedEventGroup.next() + private val eventLoop: EventLoop = sharedEventLoopGroup!!.next() private var artemisState: ArtemisState = ArtemisState.STOPPED set(value) { logDebugWithMDC { "State change $field to $value" } @@ -152,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private var scheduledExecutorService: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build()) - @Suppress("ClassNaming") - private sealed class ArtemisState { - object STARTING : ArtemisState() - data class STARTED(override val pending: ScheduledFuture) : ArtemisState() - - object CHECKING : ArtemisState() - object RESTARTED : ArtemisState() - object RECEIVING : ArtemisState() - - object AMQP_STOPPED : ArtemisState() - object AMQP_STARTING : ArtemisState() - object AMQP_STARTED : ArtemisState() - object AMQP_RESTARTED : ArtemisState() - - object STOPPING : ArtemisState() - object STOPPED : ArtemisState() - data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture) : ArtemisState() - - open val pending: ScheduledFuture? = null - - override fun toString(): String = javaClass.simpleName - } - private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) { val runnable = { - synchronized(artemis) { + synchronized(artemis!!) { try { val precedingState = artemisState artemisState.pending?.cancel(false) @@ -231,7 +210,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } ArtemisState.STOPPING } - bridgeMetricsService?.bridgeDisconnected(targets, legalNames) + bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames) connectedSubscription?.unsubscribe() connectedSubscription = null // Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first. @@ -243,7 +222,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, if (connected) { logInfoWithMDC("Bridge Connected") - bridgeMetricsService?.bridgeConnected(targets, legalNames) + bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames) if (bridgeConnectionTTLSeconds > 0) { // AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS, @@ -253,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } } artemis(ArtemisState.STARTING) { - val startedArtemis = artemis.started + val startedArtemis = artemis!!.started if (startedArtemis == null) { logInfoWithMDC("Bridge Connected but Artemis is disconnected") ArtemisState.STOPPED @@ -286,7 +265,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, logInfoWithMDC("Bridge Disconnected") amqpRestartEvent?.cancel(false) if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) { - bridgeMetricsService?.bridgeDisconnected(targets, legalNames) + bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames) } artemis(ArtemisState.STOPPING) { precedingState: ArtemisState -> logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected") @@ -418,10 +397,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore, properties[key] = value } } - logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } + logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } val peerInbox = translateLocalQueueToInboxAddress(queueName) val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox, - legalNames.first().toString(), + allowedRemoteLegalNames.first().toString(), properties) sendableMessage.onComplete.then { logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" } @@ -457,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } } + @Suppress("ClassNaming") + private sealed class ArtemisState { + object STARTING : ArtemisState() + data class STARTED(override val pending: ScheduledFuture) : ArtemisState() + + object CHECKING : ArtemisState() + object RESTARTED : ArtemisState() + object RECEIVING : ArtemisState() + + object AMQP_STOPPED : ArtemisState() + object AMQP_STARTING : ArtemisState() + object AMQP_STARTED : ArtemisState() + object AMQP_RESTARTED : ArtemisState() + + object STOPPING : ArtemisState() + object STOPPED : ArtemisState() + data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture) : ArtemisState() + + open val pending: ScheduledFuture? = null + + override fun toString(): String = javaClass.simpleName + } + override fun deployBridge(sourceX500Name: String, queueName: String, targets: List, legalNames: Set) { lock.withLock { val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() } @@ -467,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) } - val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!, - bridgeMetricsService, bridgeConnectionTTLSeconds) + val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig) bridges += newBridge bridgeMetricsService?.bridgeCreated(targets, legalNames) newBridge @@ -486,7 +487,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, queueNamesToBridgesMap.remove(queueName) } bridge.stop() - bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames) + bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames) } } } @@ -497,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore, // queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method. val bridges = queueNamesToBridgesMap[queueName]?.toList() destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) - bridges?.map { - it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false) - }?.toMap() ?: emptyMap() + bridges?.associate { + it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false) + } ?: emptyMap() } } override fun start() { - sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("AMQPBridge", Thread.MAX_PRIORITY)) - val artemis = artemisMessageClientFactory() + sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY)) + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge") + val artemis = artemisMessageClientFactory("ArtemisBridge") this.artemis = artemis artemis.start() } @@ -522,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore, sharedEventLoopGroup = null queueNamesToBridgesMap.clear() artemis?.stop() + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null } } } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt index 0fee8f1fba..357088bc0a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt @@ -35,7 +35,7 @@ class BridgeControlListener(private val keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, bridgeMetricsService: BridgeMetricsService? = null, trace: Boolean = false, sslHandshakeTimeout: Duration? = null, @@ -80,7 +80,7 @@ class BridgeControlListener(private val keyStore: CertificateStore, bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId" bridgeManager.start() - val artemis = artemisMessageClientFactory() + val artemis = artemisMessageClientFactory("BridgeControl") this.artemis = artemis artemis.start() val artemisClient = artemis.started!! diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt index e9ac1ca522..2dd9f8bff0 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt @@ -37,7 +37,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, private val bridgeMetricsService: BridgeMetricsService? = null, private val isLocalInbox: (String) -> Boolean, trace: Boolean, @@ -204,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore, override fun start() { super.start() - val artemis = artemisMessageClientFactory() + val artemis = artemisMessageClientFactory("LoopbackBridge") this.artemis = artemis artemis.start() } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt index 33036f95ff..d617b7fb0f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt @@ -386,7 +386,7 @@ object X509Utilities { private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) { if (crlDistPoint != null) { - val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) + val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier)) val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(crlIssuer)) } @@ -406,6 +406,8 @@ object X509Utilities { bytes[0] = bytes[0].and(0x3F).or(0x40) return BigInteger(bytes) } + + fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string)) } // Assuming cert type to role is 1:1 diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 7920f6f3e9..a0735a9ba0 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -27,16 +27,17 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME import net.corda.nodeapi.internal.requireMessageSize +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import rx.Observable import rx.subjects.PublishSubject import java.lang.Long.min import java.net.InetSocketAddress -import java.security.cert.CertPathValidatorException +import java.util.concurrent.Executor +import java.util.concurrent.ExecutorService +import java.util.concurrent.ThreadPoolExecutor import java.time.Duration import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock enum class ProxyVersion { @@ -63,8 +64,8 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA class AMQPClient(private val targets: List, val allowedRemoteLegalNames: Set, private val configuration: AMQPConfiguration, - private val sharedThreadPool: EventLoopGroup? = null, - private val threadPoolName: String = "AMQPClient") : AutoCloseable { + private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"), + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) @@ -84,14 +85,12 @@ class AMQPClient(private val targets: List, private val lock = ReentrantLock() @Volatile private var started: Boolean = false - private var workerGroup: EventLoopGroup? = null @Volatile private var clientChannel: Channel? = null // Offset into the list of targets, so that we can implement round-robin reconnect logic. private var targetIndex = 0 private var currentTarget: NetworkHostAndPort = targets.first() private var retryInterval = MIN_RETRY_INTERVAL - private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker() private val handshakeFailureRetryTargets = mutableSetOf() private var retryingHandshakeFailures = false private var retryOffset = 0 @@ -172,7 +171,7 @@ class AMQPClient(private val targets: List, log.info("Failed to connect to $currentTarget", future.cause()) if (started) { - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -191,7 +190,7 @@ class AMQPClient(private val targets: List, clientChannel = null if (started && !amqpActive) { log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" } - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -199,17 +198,16 @@ class AMQPClient(private val targets: List, } private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer() { - private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore) + private val trustManagerFactory = trustManagerFactoryWithRevocation( + parent.configuration.trustStore, + parent.configuration.revocationConfig, + parent.distPointCrlSource + ) private val conf = parent.configuration @Volatile private lateinit var amqpChannelHandler: AMQPChannelHandler - init { - keyManagerFactory.init(conf.keyStore) - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker)) - } - @Suppress("ComplexMethod") override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() @@ -249,9 +247,22 @@ class AMQPClient(private val targets: List, val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val target = parent.currentTarget val handler = if (parent.configuration.useOpenSsl) { - createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) + createClientOpenSslHandler( + target, + parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, + trustManagerFactory, + ch.alloc(), + parent.nettyThreading.sslDelegatedTaskExecutor + ) } else { - createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) + createClientSslHandler( + target, + parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, + trustManagerFactory, + parent.nettyThreading.sslDelegatedTaskExecutor + ) } handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis() pipeline.addLast("sslHandler", handler) @@ -292,7 +303,7 @@ class AMQPClient(private val targets: List, if (started && amqpActive) { log.debug { "Scheduling restart of $currentTarget (AMQP active)" } - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -309,7 +320,7 @@ class AMQPClient(private val targets: List, return } log.info("Connect to: $currentTarget") - workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) + (nettyThreading as? NettyThreading.NonShared)?.start() started = true restart() } @@ -321,7 +332,7 @@ class AMQPClient(private val targets: List, } val bootstrap = Bootstrap() // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux - bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) + bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) // Delegate DNS Resolution to the proxy side, if we are using proxy. if (configuration.proxyConfig != null) { bootstrap.resolver(NoopAddressResolverGroup.INSTANCE) @@ -335,14 +346,12 @@ class AMQPClient(private val targets: List, lock.withLock { log.info("Stopping connection to: $currentTarget, Local address: $localAddressString") started = false - if (sharedThreadPool == null) { - workerGroup?.shutdownGracefully() - workerGroup?.terminationFuture()?.sync() + if (nettyThreading is NettyThreading.NonShared) { + nettyThreading.stop() } else { clientChannel?.close()?.sync() } clientChannel = null - workerGroup = null log.info("Stopped connection to $currentTarget") } } @@ -384,5 +393,35 @@ class AMQPClient(private val targets: List, val onConnection: Observable get() = _onConnection - val softFailExceptions: List get() = revocationChecker.softFailExceptions -} \ No newline at end of file + + sealed class NettyThreading { + abstract val eventLoopGroup: EventLoopGroup + abstract val sslDelegatedTaskExecutor: Executor + + class Shared(override val eventLoopGroup: EventLoopGroup, + override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading() + + class NonShared(val threadPoolName: String) : NettyThreading() { + private var _eventLoopGroup: NioEventLoopGroup? = null + override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup) + + private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null + override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor) + + fun start() { + check(_eventLoopGroup == null) + check(_sslDelegatedTaskExecutor == null) + _eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) + _sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + } + + fun stop() { + eventLoopGroup.shutdownGracefully() + eventLoopGroup.terminationFuture().sync() + sslDelegatedTaskExecutor.shutdown() + _eventLoopGroup = null + _sslDelegatedTaskExecutor = null + } + } + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt index cbeb2562b4..523cde184a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt @@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.requireMessageSize +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import org.apache.qpid.proton.engine.Delivery import rx.Observable import rx.subjects.PublishSubject import java.net.BindException import java.net.InetSocketAddress -import java.security.cert.CertPathValidatorException import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutorService import java.util.concurrent.locks.ReentrantLock -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock /** @@ -39,37 +38,34 @@ import kotlin.concurrent.withLock class AMQPServer(val hostName: String, val port: Int, private val configuration: AMQPConfiguration, - private val threadPoolName: String = "AMQPServer") : AutoCloseable { + private val threadPoolName: String = "AMQPServer", + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, + private val remotingThreads: Int? = null) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) } - private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads" - private val log = contextLogger() - private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4) + private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4) } private val lock = ReentrantLock() - @Volatile - private var stopping: Boolean = false private var bossGroup: EventLoopGroup? = null private var workerGroup: EventLoopGroup? = null private var serverChannel: Channel? = null - private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker() + private var sslDelegatedTaskExecutor: ExecutorService? = null private val clientChannels = ConcurrentHashMap() private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer() { - private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore) + private val trustManagerFactory = trustManagerFactoryWithRevocation( + parent.configuration.trustStore, + parent.configuration.revocationConfig, + parent.distPointCrlSource + ) private val conf = parent.configuration - init { - keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray()) - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker)) - } - override fun initChannel(ch: SocketChannel) { val amqpConfiguration = parent.configuration val pipeline = ch.pipeline() @@ -116,11 +112,12 @@ class AMQPServer(val hostName: String, Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) } else { val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig) + val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor) val handler = if (amqpConfig.useOpenSsl) { - createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc()) + createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor) } else { // For javaSSL, SNI matching is handled at key manager level. - createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory) + createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor) } handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis() Pair(handler, mapOf(DEFAULT to keyManagerFactory)) @@ -132,8 +129,13 @@ class AMQPServer(val hostName: String, lock.withLock { stop() + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY)) - workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)) + workerGroup = NioEventLoopGroup( + remotingThreads ?: DEFAULT_REMOTING_THREADS, + DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY) + ) val server = ServerBootstrap() // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux @@ -154,22 +156,19 @@ class AMQPServer(val hostName: String, fun stop() { lock.withLock { - try { - stopping = true - serverChannel?.apply { close() } - serverChannel = null + serverChannel?.close() + serverChannel = null - workerGroup?.shutdownGracefully() - workerGroup?.terminationFuture()?.sync() + workerGroup?.shutdownGracefully() + workerGroup?.terminationFuture()?.sync() + workerGroup = null - bossGroup?.shutdownGracefully() - bossGroup?.terminationFuture()?.sync() + bossGroup?.shutdownGracefully() + bossGroup?.terminationFuture()?.sync() + bossGroup = null - workerGroup = null - bossGroup = null - } finally { - stopping = false - } + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null } } @@ -226,6 +225,4 @@ class AMQPServer(val hostName: String, private val _onConnection = PublishSubject.create().toSerialized() val onConnection: Observable get() = _onConnection - - val softFailExceptions: List get() = revocationChecker.softFailExceptions } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt index 538bb17a1e..a853cbffc8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt @@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() { override fun getSoftFailExceptions(): List { return Collections.emptyList() } + + override fun clone(): AllowAllRevocationChecker = this } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt index 6e8c695387..4e1b4b1930 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt @@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty import com.typesafe.config.Config import net.corda.nodeapi.internal.config.ConfigParser import net.corda.nodeapi.internal.config.CustomConfigParser -import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource -import net.corda.nodeapi.internal.revocation.CordaRevocationChecker -import java.security.cert.PKIXRevocationChecker /** * Data structure for controlling the way how Certificate Revocation Lists are handled. @@ -45,18 +42,6 @@ interface RevocationConfig { * Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE` */ val externalCrlSource: CrlSource? - - fun createPKIXRevocationChecker(): PKIXRevocationChecker { - return when (mode) { - Mode.OFF -> AllowAllRevocationChecker - Mode.EXTERNAL_SOURCE -> { - val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" } - CordaRevocationChecker(externalCrlSource, softFail = true) - } - Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true) - Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false) - } - } } /** diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt index ad5260f9a6..6d8bc6b344 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt @@ -1,3 +1,5 @@ +@file:Suppress("ComplexMethod", "LongParameterList") + package net.corda.nodeapi.internal.protonwrapper.netty import io.netty.buffer.ByteBufAllocator @@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.x509 +import net.corda.nodeapi.internal.namedThreadPoolExecutor +import net.corda.nodeapi.internal.revocation.CordaRevocationChecker import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.ASN1Primitive import org.bouncycastle.asn1.ASN1IA5String @@ -34,10 +38,10 @@ import java.net.URI import java.security.KeyStore import java.security.cert.CertificateException import java.security.cert.PKIXBuilderParameters -import java.security.cert.PKIXRevocationChecker import java.security.cert.X509CertSelector import java.security.cert.X509Certificate import java.util.concurrent.Executor +import java.util.concurrent.ThreadPoolExecutor import javax.net.ssl.CertPathTrustManagerParameters import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SNIHostName @@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine import javax.net.ssl.TrustManagerFactory import javax.net.ssl.X509ExtendedTrustManager import javax.security.auth.x500.X500Principal -import kotlin.system.measureTimeMillis private const val HOSTNAME_FORMAT = "%s.corda.net" internal const val DEFAULT = "default" @@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton /** * Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any. */ -@Suppress("ComplexMethod") fun X509Certificate.distributionPoints(): Map?> { logger.debug { "Checking CRLDPs for $subjectX500Principal" } @@ -117,6 +119,14 @@ fun certPathToString(certPath: Array?): String { return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" } } +/** + * Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3, + * which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor. + */ +fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor { + return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask") +} + @VisibleForTesting class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() { companion object { @@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex } -private object LoggingImmediateExecutor : Executor { - - override fun execute(command: Runnable) { - val log = LoggerFactory.getLogger(javaClass) - - @Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions - try { - val commandName = command::class.qualifiedName?.let { "[$it]" } ?: "" - log.debug("Entering SSL command $commandName") - val elapsedTime = measureTimeMillis { command.run() } - log.debug("Exiting SSL command $elapsedTime millis") - if (elapsedTime > 100) { - log.info("Command: $commandName took $elapsedTime millis to execute") - } - } - catch (ex: Exception) { - log.error("Caught exception in SSL handler executor", ex) - throw ex - } - } -} - internal fun createClientSslHandler(target: NetworkHostAndPort, expectedRemoteLegalNames: Set, keyManagerFactory: KeyManagerFactory, - trustManagerFactory: TrustManagerFactory): SslHandler { + trustManagerFactory: TrustManagerFactory, + delegateTaskExecutor: Executor): SslHandler { val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslEngine = sslContext.createSSLEngine(target.host, target.port) sslEngine.useClientMode = true @@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort, sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslEngine.sslParameters = sslParameters } - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createClientOpenSslHandler(target: NetworkHostAndPort, expectedRemoteLegalNames: Set, keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory, - alloc: ByteBufAllocator): SslHandler { + alloc: ByteBufAllocator, + delegateTaskExecutor: Executor): SslHandler { val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() val sslEngine = sslContext.newEngine(alloc, target.host, target.port) sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() @@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort, sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslEngine.sslParameters = sslParameters } - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createServerSslHandler(keyStore: CertificateStore, keyManagerFactory: KeyManagerFactory, - trustManagerFactory: TrustManagerFactory): SslHandler { + trustManagerFactory: TrustManagerFactory, + delegateTaskExecutor: Executor): SslHandler { val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslEngine = sslContext.createSSLEngine() sslEngine.useClientMode = false @@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore, val sslParameters = sslEngine.sslParameters sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore)) sslEngine.sslParameters = sslParameters - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory, - alloc: ByteBufAllocator): SslHandler { + alloc: ByteBufAllocator, + delegateTaskExecutor: Executor): SslHandler { val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build() val sslEngine = sslContext.newEngine(alloc) sslEngine.useClientMode = false - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } -fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext { +fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext { val sslContext = SSLContext.getInstance("TLS") - val keyManagers = keyManagerFactory.keyManagers - val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java) - .map { LoggingTrustManagerWrapper(it) }.toTypedArray() - sslContext.init(keyManagers, trustManagers, newSecureRandom()) + val trustManagers = trustManagerFactory + ?.trustManagers + ?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it } + ?.toTypedArray() + sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom()) return sslContext } -fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, - revocationConfig: RevocationConfig): CertPathTrustManagerParameters { - return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker()) -} - -fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, - revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters { - val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector()) - pkixParams.addCertPathChecker(revocationChecker) - return CertPathTrustManagerParameters(pkixParams) -} - /** * Creates a special SNI handler used only when openSSL is used for AMQPServer */ @@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map { @@ -327,7 +307,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map AllowAllRevocationChecker + RevocationConfig.Mode.EXTERNAL_SOURCE -> { + val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) { + "externalCrlSource must be specfied for EXTERNAL_SOURCE" + } + CordaRevocationChecker(externalCrlSource, softFail = true) + } + RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true) + RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false) + } + val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector()) + pkixParams.addCertPathChecker(revocationChecker) + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams)) + return trustManagerFactory +} /** * Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt index db984f11b8..ee589e73a9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt @@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache import net.corda.core.internal.readFully import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource @@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints import java.net.URI import java.security.cert.X509CRL import java.security.cert.X509Certificate -import java.util.concurrent.TimeUnit +import java.time.Duration import javax.security.auth.x500.X500Principal /** - * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate. + * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them. */ @Suppress("TooGenericExceptionCaught") -class CertDistPointCrlSource : CrlSource { +class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE, + cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY, + private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT, + private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource { companion object { private val logger = contextLogger() // The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a // node handshake, we want to keep the total timeout within that. - private const val DEFAULT_CONNECT_TIMEOUT = 9_000 - private const val DEFAULT_READ_TIMEOUT = 9_000 + private val DEFAULT_CONNECT_TIMEOUT = 9.seconds + private val DEFAULT_READ_TIMEOUT = 9.seconds private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore) - private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L + private val DEFAULT_CACHE_EXPIRY = 5.minutes - private val cache: LoadingCache = Caffeine.newBuilder() - .maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE)) - .expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS) - .build(::retrieveCRL) + val SINGLETON = CertDistPointCrlSource( + cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE), + cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY, + connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT, + readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT + ) + } - private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT) - private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT) + private val cache: LoadingCache = Caffeine.newBuilder() + .maximumSize(cacheSize) + .expireAfterWrite(cacheExpiry) + .build(::retrieveCRL) - private fun retrieveCRL(uri: URI): X509CRL { - val start = System.currentTimeMillis() - val bytes = try { - val conn = uri.toURL().openConnection() - conn.connectTimeout = connectTimeout - conn.readTimeout = readTimeout - // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes - // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException - // into CRLException and drops the cause chain. - conn.getInputStream().readFully() - } catch (e: Exception) { - if (logger.isDebugEnabled) { - logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) - } - throw e + private fun retrieveCRL(uri: URI): X509CRL { + val start = System.currentTimeMillis() + val bytes = try { + val conn = uri.toURL().openConnection() + conn.connectTimeout = connectTimeout.toMillis().toInt() + conn.readTimeout = readTimeout.toMillis().toInt() + // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes + // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException + // into CRLException and drops the cause chain. + conn.getInputStream().readFully() + } catch (e: Exception) { + if (logger.isDebugEnabled) { + logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) } - val duration = System.currentTimeMillis() - start - val crl = try { - X509CertificateFactory().generateCRL(bytes.inputStream()) - } catch (e: Exception) { - if (logger.isDebugEnabled) { - logger.debug("Invalid CRL from $uri (${duration}ms)", e) - } - throw e - } - logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" } - return crl + throw e } + val duration = System.currentTimeMillis() - start + val crl = try { + X509CertificateFactory().generateCRL(bytes.inputStream()) + } catch (e: Exception) { + if (logger.isDebugEnabled) { + logger.debug("Invalid CRL from $uri (${duration}ms)", e) + } + throw e + } + logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" } + return crl + } + + fun clearCache() { + cache.invalidateAll() } override fun fetch(certificate: X509Certificate): Set { diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt index 21c7bc8a94..e951c587c2 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt @@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.assertj.core.api.Assertions import org.junit.Rule import org.junit.Test @@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { return SSLContext.getInstance("TLS").apply { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(keyStore) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(trustStore) val trustManagers = trustMgrFactory.trustManagers init(keyManagers, trustManagers, newSecureRandom()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt index 46b6bf381c..0bb81e5627 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt @@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.assertj.core.api.Assertions import org.junit.Ignore import org.junit.Rule @@ -18,7 +19,6 @@ import java.io.IOException import java.net.InetAddress import java.net.InetSocketAddress import javax.net.ssl.* -import javax.net.ssl.SNIHostName import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { return SSLContext.getInstance("TLS").apply { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(keyStore) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(trustStore) val trustManagers = trustMgrFactory.trustManagers init(keyManagers, trustManagers, newSecureRandom()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt index 8fbe5fd9d9..12eb6d3e35 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt @@ -1,5 +1,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty +import io.netty.util.concurrent.ImmediateExecutor import net.corda.core.crypto.SecureHash import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.NetworkHostAndPort @@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS +import net.corda.testing.internal.fixedCrlSource import org.junit.Test -import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SNIHostName -import javax.net.ssl.TrustManagerFactory import kotlin.test.assertEquals class SSLHelperTest { @@ -20,15 +20,21 @@ class SSLHelperTest { val legalName = CordaX500Name("Test", "London", "GB") val sslConfig = configureTestSSL(legalName) - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) - val keyStore = sslConfig.keyStore - keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false)) - val trustStore = sslConfig.trustStore - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL))) + val trustManagerFactory = trustManagerFactoryWithRevocation( + sslConfig.trustStore.get(), + RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL), + fixedCrlSource(emptySet()) + ) - val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory) + val sslHandler = createClientSslHandler( + NetworkHostAndPort("localhost", 1234), + setOf(legalName), + keyManagerFactory, + trustManagerFactory, + ImmediateExecutor.INSTANCE + ) val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase() // These hardcoded values must not be changed, something is broken if you have to change these hardcoded values. diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt index 489c66de32..66c17e4a39 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt @@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation import net.corda.core.crypto.Crypto import net.corda.core.utilities.NetworkHostAndPort -import net.corda.nodeapi.internal.createDevNodeCa -import net.corda.testing.core.ALICE_NAME +import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA import net.corda.testing.node.internal.network.CrlServer import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.jce.provider.BouncyCastleProvider import org.junit.After import org.junit.Before import org.junit.Test -import java.math.BigInteger class CertDistPointCrlSourceTest { private lateinit var crlServer: CrlServer @@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest { assertThat(single().revokedCertificates).isNull() } - val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate) + crlSource.clearCache() - crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN) - with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache + crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate + with(crlSource.fetch(crlServer.intermediateCa.certificate)) { assertThat(size).isEqualTo(1) val revokedCertificates = single().revokedCertificates - assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN) + // This also tests clearCache() works. + assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber) } } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt index 6ccf8ce594..6dbfcd4515 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt @@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource +import net.corda.testing.internal.fixedCrlSource import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory import org.junit.Test import java.math.BigInteger @@ -41,10 +41,8 @@ class CordaRevocationCheckerTest { val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl") val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL - val crlSource = object : CrlSource { - override fun fetch(certificate: X509Certificate): Set = setOf(crl) - } - val checker = CordaRevocationChecker(crlSource, + val checker = CordaRevocationChecker( + crlSource = fixedCrlSource(setOf(crl)), softFail = true, dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) } ) diff --git a/node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt similarity index 60% rename from node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt index 2495fb370e..675f222f92 100644 --- a/node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt @@ -1,20 +1,16 @@ -package net.corda.node.internal.artemis +package net.corda.nodeapi.internal.revocation import net.corda.core.crypto.Crypto -import net.corda.core.utilities.days -import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType +import net.corda.nodeapi.internal.crypto.X509KeyStore import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation +import net.corda.testing.core.createCRL import org.bouncycastle.asn1.x500.X500Name -import org.bouncycastle.asn1.x509.CRLReason -import org.bouncycastle.asn1.x509.Extension -import org.bouncycastle.asn1.x509.ExtensionsGenerator -import org.bouncycastle.asn1.x509.GeneralName -import org.bouncycastle.asn1.x509.GeneralNames -import org.bouncycastle.asn1.x509.IssuingDistributionPoint -import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import org.junit.Before import org.junit.Rule import org.junit.Test @@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith import org.junit.runners.Parameterized import java.io.File +import java.security.KeyPair import java.security.KeyStore import java.security.PrivateKey +import java.security.cert.CertificateException import java.security.cert.X509Certificate import java.util.* +import javax.net.ssl.X509TrustManager import javax.security.auth.x500.X500Principal -import kotlin.test.assertFails +import kotlin.test.assertFailsWith @RunWith(Parameterized::class) -class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { +class RevocationTest(private val revocationMode: RevocationConfig.Mode) { companion object { @JvmStatic @Parameterized.Parameters(name = "revocationMode = {0}") @@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { private lateinit var doormanCRL: File private lateinit var tlsCRL: File - private val keyStore = KeyStore.getInstance("JKS") - private val trustStore = KeyStore.getInstance("JKS") + private lateinit var trustManager: X509TrustManager private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) @@ -61,7 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { private lateinit var tlsCert: X509Certificate private val chain - get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).toTypedArray() + get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert) @Before fun before() { @@ -72,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair) tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair) + val trustStore = KeyStore.getInstance("JKS") trustStore.load(null, null) trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert) trustStore.setCertificateEntry("cordarootca", rootCert) + val trustManagerFactory = trustManagerFactoryWithRevocation( + CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"), + RevocationConfigImpl(revocationMode), + CertDistPointCrlSource() + ) + trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager + doormanCert = X509Utilities.createCertificate( CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public, crlDistPoint = rootCRL.toURI().toString() @@ -89,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - rootCRL.createCRL(rootCert, rootKeyPair.private, false) - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false) - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) + rootCRL.writeCRL(rootCert, rootKeyPair.private, false) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) } - private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { - val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date()) - builder.setNextUpdate(Date.from(Date().toInstant() + 7.days)) - builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false)) - revoked.forEach { - val extensionsGenerator = ExtensionsGenerator() - extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise)) - // Certificate issuer is required for indirect CRL - val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded) - extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName))) - builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate()) - } - val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey)) - outputStream().use { it.write(holder.encoded) } + private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { + val crl = createCRL( + CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)), + revoked.asList(), + indirect = indirect + ) + writeBytes(crl.encoded) } - private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) { - if (revocationMode in modes) assertFails(block) else block() + private fun assertFailsFor(vararg modes: RevocationConfig.Mode) { + if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck() } @Test(timeout = 300_000) fun `ok with empty CRLs`() { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() } @Test(timeout = 300_000) fun `soft fail with revoked TLS certificate`() { - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) - assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -136,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -148,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown") ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -160,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString() ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -172,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public, crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() } @Test(timeout = 300_000) fun `soft fail with revoked node CA certificate`() { - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) - assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -193,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman" ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -205,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public, crlDistPoint = doormanCRL.toURI().toString() ) - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert) - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() + } + + private fun doRevocationCheck() { + trustManager.checkClientTrusted(chain, "ECDHE_ECDSA") } } diff --git a/node/build.gradle b/node/build.gradle index 05dc5e58a5..a766182f87 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -273,8 +273,6 @@ tasks.register('integrationTest', Test) { testClassesDirs = sourceSets.integrationTest.output.classesDirs classpath = sourceSets.integrationTest.runtimeClasspath maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger() - // CertificateRevocationListNodeTests - systemProperty 'net.corda.dpcrl.connect.timeout', '4000' } tasks.register('slowIntegrationTest', Test) { diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt index f3fb153013..62f3168e0b 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt @@ -15,12 +15,15 @@ import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionResult -import net.corda.nodeapi.internal.protonwrapper.netty.init -import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.internal.fixedCrlSource import org.junit.Assume.assumeFalse import org.junit.Before import org.junit.Rule @@ -98,11 +101,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { override val maxMessageSize: Int = MAX_MESSAGE_SIZE } - serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + serverKeyManagerFactory = keyManagerFactory(keyStore) - serverKeyManagerFactory.init(keyStore) - serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig)) + serverTrustManagerFactory = trustManagerFactoryWithRevocation( + serverAmqpConfig.trustStore, + RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL), + fixedCrlSource(emptySet()) + ) } private fun setupClientCertificates() { @@ -129,11 +134,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { override val sslHandshakeTimeout: Duration = 3.seconds } - clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + clientKeyManagerFactory = keyManagerFactory(keyStore) - clientKeyManagerFactory.init(keyStore) - clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig)) + clientTrustManagerFactory = trustManagerFactoryWithRevocation( + clientAmqpConfig.trustStore, + RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL), + fixedCrlSource(emptySet()) + ) } @Test(timeout = 300_000) diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt index 788874f436..d7649fcfef 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt @@ -1,3 +1,5 @@ +@file:Suppress("LongParameterList") + package net.corda.node.amqp import com.nhaarman.mockito_kotlin.doReturn @@ -5,10 +7,10 @@ import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.Crypto import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div -import net.corda.core.internal.rootCause import net.corda.core.internal.times -import net.corda.core.toFuture import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.NodeConfiguration @@ -18,64 +20,68 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.config.CertificateStoreSupplier import net.corda.nodeapi.internal.config.MutualSslConfiguration -import net.corda.nodeapi.internal.crypto.X509Utilities +import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA +import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer +import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.node.internal.network.CrlServer import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL -import net.corda.testing.node.internal.network.CrlServer.Companion.FORBIDDEN_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.RoutingType import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatIllegalArgumentException import org.bouncycastle.jce.provider.BouncyCastleProvider import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder -import java.net.SocketTimeoutException +import java.io.Closeable import java.security.cert.X509Certificate import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import kotlin.test.assertEquals +import java.util.stream.IntStream -@Suppress("LongParameterList") -class CertificateRevocationListNodeTests { +abstract class AbstractServerRevocationTest { @Rule @JvmField val temporaryFolder = TemporaryFolder() private val portAllocation = incrementalPortAllocation() - private val serverPort = portAllocation.nextPort() + protected val serverPort = portAllocation.nextPort() - private lateinit var crlServer: CrlServer - private lateinit var amqpServer: AMQPServer - private lateinit var amqpClient: AMQPClient + protected lateinit var crlServer: CrlServer + private val amqpClients = ArrayList() - private abstract class AbstractNodeConfiguration : NodeConfiguration + protected lateinit var defaultCrlDistPoints: CrlDistPoints + + protected abstract class AbstractNodeConfiguration : NodeConfiguration companion object { private val unreachableIpCounter = AtomicInteger(1) - private val crlConnectTimeout = Duration.ofMillis(System.getProperty("net.corda.dpcrl.connect.timeout").toLong()) + val crlConnectTimeout = 2.seconds /** * Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes * may not work as the OS process may cache the timeout result. */ - private fun newUnreachableIpAddress(): String { + private fun newUnreachableIpAddress(): NetworkHostAndPort { check(unreachableIpCounter.get() != 255) - return "10.255.255.${unreachableIpCounter.getAndIncrement()}" + return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement()) } } @@ -85,252 +91,190 @@ class CertificateRevocationListNodeTests { Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME) crlServer = CrlServer(NetworkHostAndPort("localhost", 0)) crlServer.start() + defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort) } @After fun tearDown() { - if (::amqpClient.isInitialized) { - amqpClient.close() - } - if (::amqpServer.isInitialized) { - amqpServer.close() - } + amqpClients.parallelStream().forEach(AMQPClient::close) if (::crlServer.isInitialized) { crlServer.close() } } @Test(timeout=300_000) - fun `AMQP server connection works and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - expectedConnectStatus = true + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection works and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection succeeds when soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, - expectedConnectStatus = true + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection fails when client's certificate is revoked and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection fails when client's certificate is revoked and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, revokeClientCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection fails when client's certificate is revoked and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when client's certificate is revoked and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, revokeClientCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection fails when servers's certificate is revoked and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection fails when server's certificate is revoked and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, revokeServerCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection fails when servers's certificate is revoked and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when server's certificate is revoked and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, revokeServerCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL cannot be obtained and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", - expectedConnectStatus = true + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"), + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection fails when CRL cannot be obtained and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when CRL cannot be obtained and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", - expectedConnectStatus = false + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"), + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL is not defined and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = null, - expectedConnectStatus = true + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null), + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection fails when CRL is not defined and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, - nodeCrlDistPoint = null, - expectedConnectStatus = false + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null), + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL", - expectedConnectStatus = true + clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null), + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { - verifyAMQPConnection( - crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", - sslHandshakeTimeout = crlConnectTimeout * 3, - expectedConnectStatus = true + fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null), + expectedConnectedStatus = false ) - val timeoutExceptions = (amqpServer.softFailExceptions + amqpClient.softFailExceptions) - .map { it.rootCause } - .filterIsInstance() - assertThat(timeoutExceptions).isNotEmpty } @Test(timeout=300_000) - fun `AMQP server connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { - verifyAMQPConnection( + fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { + verifyConnection( + crlCheckSoftFail = true, + sslHandshakeTimeout = crlConnectTimeout * 4, + clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()), + expectedConnectedStatus = true + ) + } + + @Test(timeout=300_000) + fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", sslHandshakeTimeout = crlConnectTimeout / 2, - expectedConnectStatus = false + clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()), + expectedConnectedStatus = false ) } - @Test(timeout=300_000) - fun `verify CRL algorithms`() { - val crl = crlServer.createRevocationList( - "SHA256withECDSA", - crlServer.rootCa, - EMPTY_CRL, - true, - emptyList() + @Test(timeout = 300_000) + fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() { + val serverCrlSource = CertDistPointCrlSource() + // Start the server and verify the first client has connected + val firstClientConnectionChangeStatus = verifyConnection( + crlCheckSoftFail = true, + crlSource = serverCrlSource, + // In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with + // existing clients. The trick is to make sure at least N new clients are also supported. + remotingThreads = 2, + expectedConnectedStatus = true ) - // This should pass. - crl.verify(crlServer.rootCa.keyPair.public) - // Try changing the algorithm to EC will fail. - assertThatIllegalArgumentException().isThrownBy { - crlServer.createRevocationList( - "EC", - crlServer.rootCa, - EMPTY_CRL, - true, - emptyList() + // Now simulate the CRL endpoint becoming very slow/unreachable + crlServer.delay = 10.minutes + // And pretend enough time has elapsed that the cached CRLs have expired and need downloading again + serverCrlSource.clearCache() + + // Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty + // threads to be tied up in trying to download the CRLs. + IntStream.range(0, 2).parallel().forEach { clientIndex -> + val (newClient, _) = createAMQPClient( + serverPort, + crlCheckSoftFail = true, + legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"), + crlDistPoints = defaultCrlDistPoints ) - }.withMessage("Unknown signature type requested: EC") + newClient.start() + } + + // Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga + assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull() } - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged - ) - } + protected abstract fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout), + sslHandshakeTimeout: Duration? = null, + remotingThreads: Int? = null, + clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints, + revokeClientCert: Boolean = false, + revokeServerCert: Boolean = false, + expectedConnectedStatus: Boolean): BlockingQueue - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with hard fail CRL check`() { - verifyArtemisConnection( - crlCheckSoftFail = false, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL" - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check on unreachable URL if CRL timeout is within SSL handshake timeout`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", - sslHandshakeTimeout = crlConnectTimeout * 3 - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with soft fail CRL check on unreachable URL if CRL timeout is not within SSL handshake timeout`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedConnected = false, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", - sslHandshakeTimeout = crlConnectTimeout / 2 - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() { - verifyArtemisConnection( - crlCheckSoftFail = false, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Rejected, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL" - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Rejected, - revokedNodeCert = true - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() { - verifyArtemisConnection( - crlCheckSoftFail = false, - crlCheckArtemisServer = false, - expectedStatus = MessageStatus.Acknowledged, - revokedNodeCert = true - ) - } - - private fun createAMQPClient(targetPort: Int, - crlCheckSoftFail: Boolean, - legalName: CordaX500Name, - nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL", - maxMessageSize: Int = MAX_MESSAGE_SIZE): X509Certificate { + protected fun createAMQPClient(targetPort: Int, + crlCheckSoftFail: Boolean, + legalName: CordaX500Name, + crlDistPoints: CrlDistPoints): Pair { val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) @@ -344,31 +288,128 @@ class CertificateRevocationListNodeTests { doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail } clientConfig.configureWithDevSSLCertificate() - val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) + val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val keyStore = clientConfig.p2pSslOptions.keyStore.get() val amqpConfig = object : AMQPConfiguration { override val keyStore = keyStore override val trustStore = clientConfig.p2pSslOptions.trustStore.get() - override val maxMessageSize: Int = maxMessageSize + override val maxMessageSize: Int = MAX_MESSAGE_SIZE + override val trace: Boolean = true } - amqpClient = AMQPClient( + val amqpClient = AMQPClient( listOf(NetworkHostAndPort("localhost", targetPort)), setOf(CHARLIE_NAME), amqpConfig, - threadPoolName = legalName.organisation + nettyThreading = AMQPClient.NettyThreading.NonShared(legalName.organisation), + distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout) ) + amqpClients += amqpClient + return Pair(amqpClient, nodeCert) + } - return nodeCert + protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue { + val connectionChangeStatus = LinkedBlockingQueue() + onConnection.subscribe { connectionChangeStatus.add(it) } + start() + assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus) + return connectionChangeStatus + } + + protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort, + val nodeCa: String? = NODE_CRL, + val tls: String? = EMPTY_CRL) { + private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" } + private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" } + + fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, + p2pSslConfiguration: MutualSslConfiguration, + crlServer: CrlServer): X509Certificate { + val nodeKeyStore = signingCertificateStore.get() + val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) } + val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint) + val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) + + nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2) + + nodeKeyStore.update { + internal.deleteEntry(CORDA_CLIENT_CA) + } + nodeKeyStore.update { + setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword) + } + + val sslKeyStore = p2pSslConfiguration.keyStore.get() + val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) } + val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal) + val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) + + sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3) + + sslKeyStore.update { + internal.deleteEntry(CORDA_CLIENT_TLS) + } + sslKeyStore.update { + setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword) + } + return newNodeCert + } + } +} + + +class AMQPServerRevocationTest : AbstractServerRevocationTest() { + private lateinit var amqpServer: AMQPServer + + @After + fun shutDown() { + if (::amqpServer.isInitialized) { + amqpServer.close() + } + } + + override fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?, + clientCrlDistPoints: CrlDistPoints, + revokeClientCert: Boolean, + revokeServerCert: Boolean, + expectedConnectedStatus: Boolean): BlockingQueue { + val serverCert = createAMQPServer( + serverPort, + CHARLIE_NAME, + crlCheckSoftFail, + defaultCrlDistPoints, + crlSource, + sslHandshakeTimeout, + remotingThreads + ) + if (revokeServerCert) { + crlServer.revokedNodeCerts.add(serverCert) + } + amqpServer.start() + amqpServer.onReceive.subscribe { + it.complete(true) + } + val (client, clientCert) = createAMQPClient( + serverPort, + crlCheckSoftFail = crlCheckSoftFail, + legalName = ALICE_NAME, + crlDistPoints = clientCrlDistPoints + ) + if (revokeClientCert) { + crlServer.revokedNodeCerts.add(clientCert) + } + + return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) } private fun createAMQPServer(port: Int, legalName: CordaX500Name, crlCheckSoftFail: Boolean, - nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL", - maxMessageSize: Int = MAX_MESSAGE_SIZE, - sslHandshakeTimeout: Duration? = null): X509Certificate { + crlDistPoints: CrlDistPoints, + distPointCrlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?): X509Certificate { check(!::amqpServer.isInitialized) val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" @@ -382,92 +423,103 @@ class CertificateRevocationListNodeTests { doReturn(signingCertificateStore).whenever(it).signingCertificateStore } serverConfig.configureWithDevSSLCertificate() - val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) + val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val keyStore = serverConfig.p2pSslOptions.keyStore.get() val amqpConfig = object : AMQPConfiguration { override val keyStore = keyStore override val trustStore = serverConfig.p2pSslOptions.trustStore.get() override val revocationConfig = crlCheckSoftFail.toRevocationConfig() - override val maxMessageSize: Int = maxMessageSize + override val maxMessageSize: Int = MAX_MESSAGE_SIZE override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout } - amqpServer = AMQPServer("0.0.0.0", port, amqpConfig, threadPoolName = legalName.organisation) - return nodeCert - } - - private fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, - p2pSslConfiguration: MutualSslConfiguration, - nodeCaCrlDistPoint: String?, - tlsCrlDistPoint: String?): X509Certificate { - val nodeKeyStore = signingCertificateStore.get() - val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) } - val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCrlDistPoint) - val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) + - nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2) - - nodeKeyStore.update { - internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA) - } - nodeKeyStore.update { - setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword) - } - - val sslKeyStore = p2pSslConfiguration.keyStore.get() - val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) } - val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal) - val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) + - sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3) - - sslKeyStore.update { - internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS) - } - sslKeyStore.update { - setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword) - } - return newNodeCert - } - - private fun verifyAMQPConnection(crlCheckSoftFail: Boolean, - nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - revokeServerCert: Boolean = false, - revokeClientCert: Boolean = false, - sslHandshakeTimeout: Duration? = null, - expectedConnectStatus: Boolean) { - val serverCert = createAMQPServer( - serverPort, - CHARLIE_NAME, - crlCheckSoftFail = crlCheckSoftFail, - nodeCrlDistPoint = nodeCrlDistPoint, - sslHandshakeTimeout = sslHandshakeTimeout + amqpServer = AMQPServer( + "0.0.0.0", + port, + amqpConfig, + threadPoolName = legalName.organisation, + distPointCrlSource = distPointCrlSource, + remotingThreads = remotingThreads ) - if (revokeServerCert) { - crlServer.revokedNodeCerts.add(serverCert.serialNumber) + return serverCert + } +} + + +class ArtemisServerRevocationTest : AbstractServerRevocationTest() { + private lateinit var artemisNode: ArtemisNode + private var crlCheckArtemisServer = true + + @After + fun shutDown() { + if (::artemisNode.isInitialized) { + artemisNode.close() } - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val clientCert = createAMQPClient( + } + + @Test(timeout = 300_000) + fun `connection succeeds with disabled CRL check on revoked node certificate`() { + crlCheckArtemisServer = false + verifyConnection( + crlCheckSoftFail = false, + revokeClientCert = true, + expectedConnectedStatus = true + ) + } + + override fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?, + clientCrlDistPoints: CrlDistPoints, + revokeClientCert: Boolean, + revokeServerCert: Boolean, + expectedConnectedStatus: Boolean): BlockingQueue { + val (client, clientCert) = createAMQPClient( serverPort, - crlCheckSoftFail = crlCheckSoftFail, + crlCheckSoftFail = true, legalName = ALICE_NAME, - nodeCrlDistPoint = nodeCrlDistPoint + crlDistPoints = clientCrlDistPoints ) if (revokeClientCert) { - crlServer.revokedNodeCerts.add(clientCert.serialNumber) + crlServer.revokedNodeCerts.add(clientCert) } - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertThat(serverConnect.connected).isEqualTo(expectedConnectStatus) + + val nodeCert = startArtemisNode( + CHARLIE_NAME, + crlCheckSoftFail, + defaultCrlDistPoints, + crlSource, + sslHandshakeTimeout, + remotingThreads + ) + if (revokeServerCert) { + crlServer.revokedNodeCerts.add(nodeCert) + } + + val queueName = "${P2P_PREFIX}Test" + artemisNode.client.started!!.session.createQueue( + QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true) + ) + + val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) + + if (expectedConnectedStatus) { + val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap()) + client.write(msg) + assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged) + } + + return clientConnectionChangeStatus } - private fun createArtemisServerAndClient(legalName: CordaX500Name, - crlCheckSoftFail: Boolean, - crlCheckArtemisServer: Boolean, - nodeCrlDistPoint: String, - sslHandshakeTimeout: Duration?): Pair { - val baseDirectory = temporaryFolder.root.toPath() / "artemis" + private fun startArtemisNode(legalName: CordaX500Name, + crlCheckSoftFail: Boolean, + crlDistPoints: CrlDistPoints, + distPointCrlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?): X509Certificate { + check(!::artemisNode.isInitialized) + val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout) @@ -483,62 +535,34 @@ class CertificateRevocationListNodeTests { doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer } artemisConfig.configureWithDevSSLCertificate() - recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, null) + val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val server = ArtemisMessagingServer( artemisConfig, artemisConfig.p2pAddress, MAX_MESSAGE_SIZE, threadPoolName = "${legalName.organisation}-server", - trace = true + trace = true, + distPointCrlSource = distPointCrlSource, + remotingThreads = remotingThreads ) val client = ArtemisMessagingClient( artemisConfig.p2pSslOptions, artemisConfig.p2pAddress, MAX_MESSAGE_SIZE, - threadPoolName = "${legalName.organisation}-client", - trace = true + threadPoolName = "${legalName.organisation}-client" ) server.start() client.start() - return server to client + val artemisNode = ArtemisNode(server, client) + this.artemisNode = artemisNode + return nodeCert } - private fun verifyArtemisConnection(crlCheckSoftFail: Boolean, - crlCheckArtemisServer: Boolean, - expectedConnected: Boolean = true, - expectedStatus: MessageStatus? = null, - revokedNodeCert: Boolean = false, - nodeCrlDistPoint: String = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - sslHandshakeTimeout: Duration? = null) { - val queueName = P2P_PREFIX + "Test" - val (artemisServer, artemisClient) = createArtemisServerAndClient( - CHARLIE_NAME, - crlCheckSoftFail, - crlCheckArtemisServer, - nodeCrlDistPoint, - sslHandshakeTimeout - ) - artemisServer.use { - artemisClient.started!!.session.createQueue( - QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true) - ) - - val nodeCert = createAMQPClient(serverPort, true, ALICE_NAME, nodeCrlDistPoint) - if (revokedNodeCert) { - crlServer.revokedNodeCerts.add(nodeCert.serialNumber) - } - val clientConnected = amqpClient.onConnection.toFuture() - amqpClient.start() - val clientConnect = clientConnected.get() - assertThat(clientConnect.connected).isEqualTo(expectedConnected) - - if (expectedConnected) { - val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap()) - amqpClient.write(msg) - assertEquals(expectedStatus, msg.onComplete.get()) - } - artemisClient.stop() + private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable { + override fun close() { + client.stop() + server.close() } } } diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 6be8cc1002..97f77fc5e2 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.whenever import io.netty.channel.EventLoopGroup import io.netty.channel.nio.NioEventLoopGroup +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div import net.corda.core.toFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger +import net.corda.coretesting.internal.rigorousMock +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.messaging.ArtemisMessagingServer @@ -24,6 +27,9 @@ import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME @@ -31,9 +37,6 @@ import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.internal.createDevIntermediateCaCertPath -import net.corda.coretesting.internal.rigorousMock -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.RoutingType import org.assertj.core.api.Assertions @@ -44,7 +47,14 @@ import org.junit.Test import org.junit.rules.TemporaryFolder import java.security.cert.X509Certificate import java.util.concurrent.TimeUnit -import javax.net.ssl.* +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLException +import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLParameters +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket +import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -146,15 +156,10 @@ class ProtonWrapperTests { sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) } sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - val context = SSLContext.getInstance("TLS") - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get()) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -236,7 +241,7 @@ class ProtonWrapperTests { keyManagerFactory.init(keyStore) val keyManagers = keyManagerFactory.keyManagers val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + trustMgrFactory.init(trustStore.value.internal) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -442,7 +447,7 @@ class ProtonWrapperTests { amqpServer.use { val connectionEvents = amqpServer.onConnection.toBlocking().iterator amqpServer.start() - val sharedThreads = NioEventLoopGroup() + val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads")) val amqpClient1 = createSharedThreadsClient(sharedThreads, 0) val amqpClient2 = createSharedThreadsClient(sharedThreads, 1) amqpClient1.start() @@ -608,7 +613,7 @@ class ProtonWrapperTests { listOf(NetworkHostAndPort("localhost", serverPort)), setOf(ALICE_NAME), amqpConfig, - sharedThreadPool = sharedEventGroup) + nettyThreading = AMQPClient.NettyThreading.Shared(sharedEventGroup)) } private fun createServer(port: Int, diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt index bb3c86e9de..da3e831bda 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/SimpleAMQPClient.kt @@ -3,6 +3,8 @@ package net.corda.services.messaging import net.corda.core.internal.concurrent.openFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.internal.config.MutualSslConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.apache.qpid.jms.JmsConnectionFactory import org.apache.qpid.jms.meta.JmsConnectionInfo import org.apache.qpid.jms.provider.Provider @@ -24,9 +26,7 @@ import javax.jms.Connection import javax.jms.Message import javax.jms.MessageProducer import javax.jms.Session -import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SSLContext -import javax.net.ssl.TrustManagerFactory /** * Simple AMQP client connecting to broker using JMS. @@ -59,12 +59,8 @@ class SimpleAMQPClient(private val target: NetworkHostAndPort, private val confi private lateinit var connection: Connection private fun sslContext(): SSLContext { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply { - init(config.keyStore.get().value.internal, config.keyStore.entryPassword.toCharArray()) - } - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply { - init(config.trustStore.get().value.internal) - } + val keyManagerFactory = keyManagerFactory(config.keyStore.get()) + val trustManagerFactory = trustManagerFactory(config.trustStore.get()) val sslContext = SSLContext.getInstance("TLS") val keyManagers = keyManagerFactory.keyManagers val trustManagers = trustManagerFactory.trustManagers diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 1c12dc61b3..198b158d24 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -5,8 +5,8 @@ import com.codahale.metrics.Gauge import com.codahale.metrics.MetricRegistry import com.google.common.collect.MutableClassToInstanceMap import com.google.common.util.concurrent.MoreExecutors -import com.google.common.util.concurrent.ThreadFactoryBuilder import com.zaxxer.hikari.pool.HikariPool +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.common.logging.errorReporting.NodeDatabaseErrors import net.corda.confidential.SwapIdentitiesFlow import net.corda.core.CordaException @@ -73,6 +73,7 @@ import net.corda.core.toFuture import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.days +import net.corda.core.utilities.millis import net.corda.core.utilities.minutes import net.corda.djvm.source.ApiSource import net.corda.djvm.source.EmptyApi @@ -173,6 +174,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess +import net.corda.nodeapi.internal.namedThreadPoolExecutor import org.apache.activemq.artemis.utils.ReusableLatch import org.jolokia.jvmagent.JolokiaServer import org.jolokia.jvmagent.JolokiaServerConfig @@ -188,9 +190,6 @@ import java.util.ArrayList import java.util.Properties import java.util.concurrent.ExecutorService import java.util.concurrent.Executors -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.ThreadPoolExecutor -import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.SECONDS import java.util.function.Consumer @@ -355,7 +354,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, private val cordappServices = MutableClassToInstanceMap.create() private val cordappTelemetryComponents = MutableClassToInstanceMap.create() - private val shutdownExecutor = Executors.newSingleThreadExecutor() + private val shutdownExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("Shutdown")) protected abstract val transactionVerifierWorkerCount: Int /** @@ -813,7 +812,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } else { 1.days } - val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("Network Map Updater")) + val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("NetworkMapPublisher")) executor.submit(object : Runnable { override fun run() { val republishInterval = try { @@ -922,13 +921,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } // Start with 1 thread and scale up to the configured thread pool size if needed // Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool] - return ThreadPoolExecutor( - 1, - numberOfThreads, - 0L, - TimeUnit.MILLISECONDS, - LinkedBlockingQueue(), - ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build() + return namedThreadPoolExecutor( + corePoolSize = 1, + maxPoolSize = numberOfThreads, + idleKeepAlive = 0.millis, + poolName = "flow-external-operation-thread", + daemonThreads = true ) } diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 1747740035..f86bd83eda 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -415,12 +415,13 @@ open class Node(configuration: NodeConfiguration, } private fun makeBridgeControlListener(serverAddress: NetworkHostAndPort, networkParameters: NetworkParameters): BridgeControlListener { - val artemisMessagingClientFactory = { + val artemisMessagingClientFactory = { threadPoolName: String -> ArtemisMessagingClient( configuration.p2pSslOptions, serverAddress, networkParameters.maxMessageSize, - failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) } + failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) }, + threadPoolName = threadPoolName ) } return BridgeControlListener( @@ -431,7 +432,8 @@ open class Node(configuration: NodeConfiguration, networkParameters.maxMessageSize, configuration.crlCheckSoftFail.toRevocationConfig(), false, - artemisMessagingClientFactory) + artemisMessagingClientFactory + ) } private fun startLocalRpcBroker(securityManager: RPCSecurityManager): BrokerAddresses? { diff --git a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt index ffea11f536..ffb21894c1 100644 --- a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt @@ -9,6 +9,7 @@ import net.corda.core.contracts.StateRef import net.corda.core.contracts.TransactionResolutionException import net.corda.core.contracts.TransactionState import net.corda.core.cordapp.CordappProvider +import net.corda.core.crypto.SecureHash import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.uncheckedCast import net.corda.core.node.NetworkParameters @@ -16,8 +17,10 @@ import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.TransactionStorage +import net.corda.core.transactions.BaseTransaction import net.corda.core.transactions.ContractUpgradeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent @@ -33,25 +36,22 @@ data class ServicesForResolutionImpl( @Throws(TransactionResolutionException::class) override fun loadState(stateRef: StateRef): TransactionState<*> { - val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) - return stx.resolveBaseTransaction(this).outputs[stateRef.index] + return toBaseTransaction(stateRef.txhash).outputs[stateRef.index] } override fun >> loadStates(input: Iterable, output: C): C { - input.groupBy { it.txhash }.forEach { - val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key) - val baseTx = stx.resolveBaseTransaction(this) - it.value.mapTo(output) { ref -> StateAndRef(uncheckedCast(baseTx.outputs[ref.index]), ref) } + val baseTxs = HashMap() + return input.mapTo(output) { stateRef -> + val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction) + StateAndRef(uncheckedCast(baseTx.outputs[stateRef.index]), stateRef) } - return output } @Throws(TransactionResolutionException::class, AttachmentResolutionException::class) override fun loadContractAttachment(stateRef: StateRef): Attachment { // We may need to recursively chase transactions if there are notary changes. fun inner(stateRef: StateRef, forContractClassName: String?): Attachment { - val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction - ?: throw TransactionResolutionException(stateRef.txhash) + val ctx = getSignedTransaction(stateRef.txhash).coreTransaction when (ctx) { is WireTransaction -> { val transactionState = ctx.outRef(stateRef.index).state @@ -76,4 +76,10 @@ data class ServicesForResolutionImpl( } return inner(stateRef, null) } + + private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this) + + private fun getSignedTransaction(txhash: SecureHash): SignedTransaction { + return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash) + } } diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt index a58373e8dd..4038a0f2ef 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt @@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() { Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE))) } ArtemisMessagingComponent.PEER_USER -> { - requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } + val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } requireTls(certificates) // This check is redundant as it was performed already during the SSL handshake - CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) - CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode) - .createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) + CertificateChainCheckPolicy.RootMustMatch + .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore) + .checkCertificateChain(certificates) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) } else -> { diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt index 0acaae7d0e..6a4a77f071 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt @@ -2,17 +2,9 @@ package net.corda.node.internal.artemis import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.contextLogger -import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig -import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl -import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString import java.security.KeyStore -import java.security.cert.CertPathValidator -import java.security.cert.CertPathValidatorException import java.security.cert.CertificateException -import java.security.cert.PKIXBuilderParameters -import java.security.cert.X509CertSelector sealed class CertificateChainCheckPolicy { companion object { @@ -92,33 +84,4 @@ sealed class CertificateChainCheckPolicy { } } } - - class RevocationCheck(val revocationConfig: RevocationConfig) : CertificateChainCheckPolicy() { - constructor(revocationMode: RevocationConfig.Mode) : this(RevocationConfigImpl(revocationMode)) - - override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check { - return object : Check { - override fun checkCertificateChain(theirChain: Array) { - // Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate. - val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) } - log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}") - - // Drop the last certificate which must be a trusted root (validated by RootMustMatch). - // Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain. - // See PKIXValidator.engineValidate() for reference implementation. - val certPath = X509Utilities.buildCertPath(chain.dropLast(1)) - val certPathValidator = CertPathValidator.getInstance("PKIX") - val pkixRevocationChecker = revocationConfig.createPKIXRevocationChecker() - val params = PKIXBuilderParameters(trustStore, X509CertSelector()) - params.addCertPathChecker(pkixRevocationChecker) - try { - certPathValidator.validate(certPath, params) - } catch (ex: CertPathValidatorException) { - log.error("Bad certificate path", ex) - throw ex - } - } - } - } - } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index f13d1d73bf..b1e1ceb1f0 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -2,6 +2,7 @@ package net.corda.node.services.events import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationOrigin @@ -148,7 +149,7 @@ class NodeSchedulerService(private val clock: CordaClock, // from the database private val startingStateRefs: MutableSet = ConcurrentHashMap.newKeySet() private val mutex = ThreadBox(InnerState()) - private val schedulerTimerExecutor = Executors.newSingleThreadExecutor() + private val schedulerTimerExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("SchedulerService")) // if there's nothing to do, check every minute if something fell through the cracks. // any new state should trigger a reschedule immediately if nothing is scheduled, so I would not expect diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index 4b0689d18e..c0a6ed0388 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug -import net.corda.node.internal.artemis.* +import net.corda.node.internal.artemis.ArtemisBroker +import net.corda.node.internal.artemis.BrokerAddresses +import net.corda.node.internal.artemis.BrokerJaasLoginModule import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE +import net.corda.node.internal.artemis.NodeJaasConfig +import net.corda.node.internal.artemis.P2PJaasConfig +import net.corda.node.internal.artemis.SecureArtemisConfiguration +import net.corda.node.internal.artemis.UserValidationPlugin +import net.corda.node.internal.artemis.isBindingError import net.corda.node.services.config.NodeConfiguration import net.corda.node.utilities.artemis.startSynchronously import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor @@ -21,7 +28,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation import net.corda.nodeapi.internal.requireOnDefaultFileSystem +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl @@ -32,9 +42,7 @@ import org.apache.activemq.artemis.core.security.Role import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager -import java.io.IOException import java.lang.Long.max -import java.security.KeyStoreException import javax.annotation.concurrent.ThreadSafe import javax.security.auth.login.AppConfigurationEntry import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED @@ -57,8 +65,10 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, private val messagingServerAddress: NetworkHostAndPort, private val maxMessageSize: Int, private val journalBufferTimeout : Int? = null, - private val threadPoolName: String = "ArtemisServer", - private val trace: Boolean = false) : ArtemisBroker, SingletonSerializeAsToken() { + private val threadPoolName: String = "P2PServer", + private val trace: Boolean = false, + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, + private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() { companion object { private val log = contextLogger() } @@ -92,7 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, override val started: Boolean get() = activeMQServer.isStarted - @Throws(IOException::class, AddressBindingException::class, KeyStoreException::class) + @Suppress("ThrowsCount") private fun configureAndStartServer() { val artemisConfig = createArtemisConfig() val securityManager = createArtemisSecurityManager() @@ -132,11 +142,23 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, // The transaction cache is configurable, and drives other cache sizes. globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize) + val revocationMode = if (config.crlCheckArtemisServer) { + if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL + } else { + RevocationConfig.Mode.OFF + } + val trustManagerFactory = trustManagerFactoryWithRevocation( + config.p2pSslOptions.trustStore.get(), + RevocationConfigImpl(revocationMode), + distPointCrlSource + ) addAcceptorConfiguration(p2pAcceptorTcpTransport( NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port), config.p2pSslOptions, + trustManagerFactory, threadPoolName = threadPoolName, - trace = trace + trace = trace, + remotingThreads = remotingThreads )) // Enable built in message deduplication. Note we still have to do our own as the delayed commits // and our own definition of commit mean that the built in deduplication cannot remove all duplicates. diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt index 4143b20273..9e9c7ca081 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt @@ -10,6 +10,8 @@ import io.netty.handler.ssl.SslHandshakeTimeoutException import net.corda.core.internal.declaredField import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.ArtemisTcpTransport +import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor +import net.corda.nodeapi.internal.setThreadPoolName import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor import org.apache.activemq.artemis.core.server.balancing.RedirectHandler @@ -19,14 +21,18 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory import org.apache.activemq.artemis.spi.core.remoting.BufferHandler import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener +import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider +import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider import org.apache.activemq.artemis.utils.ConfigurationHelper import org.apache.activemq.artemis.utils.actors.OrderedExecutor +import java.net.SocketAddress import java.nio.channels.ClosedChannelException import java.time.Duration import java.util.concurrent.Executor import java.util.concurrent.ScheduledExecutorService import java.util.regex.Pattern import javax.net.ssl.SSLEngine +import javax.net.ssl.SSLPeerUnverifiedException @Suppress("unused") // Used via reflection in ArtemisTcpTransport class NodeNettyAcceptorFactory : AcceptorFactory { @@ -36,10 +42,23 @@ class NodeNettyAcceptorFactory : AcceptorFactory { handler: BufferHandler?, listener: ServerConnectionLifeCycleListener?, threadPool: Executor, - scheduledThreadPool: ScheduledExecutorService?, + scheduledThreadPool: ScheduledExecutorService, protocolMap: MutableMap, RedirectHandler<*>>>?): Acceptor { + val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Acceptor", configuration) + threadPool.setThreadPoolName("$threadPoolName-artemis") + scheduledThreadPool.setThreadPoolName("$threadPoolName-artemis-scheduler") val failureExecutor = OrderedExecutor(threadPool) - return NodeNettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) + return NodeNettyAcceptor( + name, + clusterConnection, + configuration, + handler, + listener, + scheduledThreadPool, + failureExecutor, + protocolMap, + "$threadPoolName-netty" + ) } @@ -50,14 +69,21 @@ class NodeNettyAcceptorFactory : AcceptorFactory { listener: ServerConnectionLifeCycleListener?, scheduledThreadPool: ScheduledExecutorService?, failureExecutor: Executor, - protocolMap: MutableMap, RedirectHandler<*>>>?) : + protocolMap: MutableMap, RedirectHandler<*>>>?, + private val threadPoolName: String) : NettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) { companion object { private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") + + init { + // Make sure Artemis isn't using another (Open)SSLContextFactory + check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory) + check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory) + } } - private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) + private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) @Synchronized @@ -71,11 +97,17 @@ class NodeNettyAcceptorFactory : AcceptorFactory { } } + @Synchronized + override fun stop() { + super.stop() + sslDelegatedTaskExecutor.shutdown() + } + @Synchronized override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler { applyThreadPoolName() val engine = super.getSslHandler(alloc, peerHost, peerPort).engine() - val sslHandler = NodeAcceptorSslHandler(engine, trace) + val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace) val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? if (handshakeTimeout != null) { sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis() @@ -95,13 +127,15 @@ class NodeNettyAcceptorFactory : AcceptorFactory { } - private class NodeAcceptorSslHandler(engine: SSLEngine, private val trace: Boolean) : SslHandler(engine) { + private class NodeAcceptorSslHandler(engine: SSLEngine, + delegatedTaskExecutor: Executor, + private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) { companion object { private val logger = contextLogger() } override fun handlerAdded(ctx: ChannelHandlerContext) { - logHandshake() + logHandshake(ctx.channel().remoteAddress()) super.handlerAdded(ctx) // Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way. if (trace) { @@ -109,17 +143,22 @@ class NodeNettyAcceptorFactory : AcceptorFactory { } } - private fun logHandshake() { + private fun logHandshake(remoteAddress: SocketAddress) { val start = System.currentTimeMillis() handshakeFuture().addListener { val duration = System.currentTimeMillis() - start + val peer = try { + engine().session.peerPrincipal + } catch (e: SSLPeerUnverifiedException) { + remoteAddress + } when { - it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with ${engine().session.peerPrincipal}") - it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms") + it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with $peer") + it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms with $peer") else -> when (it.cause()) { - is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms") - is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms") - else -> logger.warn("SSL handshake failed after ${duration}ms", it.cause()) + is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms with $peer") + is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms with $peer") + else -> logger.warn("SSL handshake failed after ${duration}ms with $peer", it.cause()) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt new file mode 100644 index 0000000000..38f60d7a57 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt @@ -0,0 +1,59 @@ +package net.corda.node.services.messaging + +import io.netty.handler.ssl.SslContext +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.SslProvider +import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory +import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory +import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig +import java.nio.file.Paths +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManagerFactory + +class NodeSSLContextFactory : DefaultSSLContextFactory() { + override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map): SSLContext { + val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? + return if (trustManagerFactory != null) { + createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory) + } else { + super.getSSLContext(config, additionalOpts) + } + } + + override fun getPriority(): Int { + // We make sure this factory is the one that's chosen, so any sufficiently large value will do. + return 15 + } +} + + +class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() { + override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map): SslContext { + val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? + return if (trustManagerFactory != null) { + SslContextBuilder + .forServer(loadKeyManagerFactory(config)) + .sslProvider(SslProvider.OPENSSL) + .trustManager(trustManagerFactory) + .build() + } else { + super.getServerSslContext(config, additionalOpts) + } + } + + override fun getPriority(): Int { + // We make sure this factory is the one that's chosen, so any sufficiently large value will do. + return 15 + } +} + + +private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory { + val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false) + return keyManagerFactory(keyStore) +} diff --git a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt index fac12a9343..584b050425 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt @@ -74,7 +74,7 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal, } private val parametersUpdatesTrack = PublishSubject.create() - private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("Network Map Updater Thread")).apply { + private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("NetworkMapUpdater")).apply { executeExistingDelayedTasksAfterShutdownPolicy = false } private var newNetworkParameters: Pair? = null @@ -261,9 +261,12 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal, //as HTTP GET is mostly IO bound, use more threads than CPU's //maximum threads to use = 24, as if we did not limit this on large machines it could result in 100's of concurrent requests val threadsToUseForNetworkMapDownload = min(Runtime.getRuntime().availableProcessors() * 4, 24) - val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(threadsToUseForNetworkMapDownload, NamedThreadFactory("NetworkMapUpdaterNodeInfoDownloadThread")) + val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool( + threadsToUseForNetworkMapDownload, + NamedThreadFactory("NetworkMapUpdaterNodeInfoDownload") + ) //DB insert is single threaded - use a single threaded executor for it. - val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsertThread")) + val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsert")) val hashesToFetch = (allHashesFromNetworkMap - allNodeHashes) val networkMapDownloadStartTime = System.currentTimeMillis() if (hashesToFetch.isNotEmpty()) { diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt index 8ed025549c..e48cdf16c0 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt @@ -22,8 +22,7 @@ class InternalRPCMessagingClient(val sslConfig: MutualSslConfiguration, val serv private var rpcServer: RPCServer? = null fun init(rpcOps: List, securityManager: RPCSecurityManager, cacheFactory: NamedCacheFactory) = synchronized(this) { - - val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig) + val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig, threadPoolName = "RPCClient") locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { // Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this // would be the default and the two lines below can be deleted. diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt index 57ea80aadd..13af138c8e 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt @@ -30,10 +30,10 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int, setDirectories(baseDirectory) val acceptorConfigurationsSet = mutableSetOf( - rpcAcceptorTcpTransport(address, sslOptions, useSsl) + rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl, threadPoolName = "RPCServer") ) adminAddress?.let { - acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration) + acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration, threadPoolName = "RPCServerAdmin") } acceptorConfigurations = acceptorConfigurationsSet diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt index c08515ab0e..734b2b8234 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt @@ -1,5 +1,6 @@ package net.corda.node.services.statemachine +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.flows.FlowSession import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine @@ -22,10 +23,6 @@ internal class FlowMonitor( ) : LifecycleSupport { private companion object { - private fun defaultScheduler(): ScheduledExecutorService { - return Executors.newSingleThreadScheduledExecutor() - } - private val logger = loggerFor() } @@ -36,7 +33,7 @@ internal class FlowMonitor( override fun start() { synchronized(this) { if (scheduler == null) { - scheduler = defaultScheduler() + scheduler = Executors.newSingleThreadScheduledExecutor(DefaultThreadFactory("FlowMonitor")) shutdownScheduler = true } scheduler!!.scheduleAtFixedRate({ logFlowsWaitingForParty() }, 0, monitoringPeriod.toMillis(), TimeUnit.MILLISECONDS) diff --git a/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory new file mode 100644 index 0000000000..6b69d9d3ff --- /dev/null +++ b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory @@ -0,0 +1 @@ +net.corda.node.services.messaging.NodeOpenSSLContextFactory diff --git a/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory new file mode 100644 index 0000000000..59e57dca26 --- /dev/null +++ b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory @@ -0,0 +1 @@ +net.corda.node.services.messaging.NodeSSLContextFactory diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 6d8c7e8392..fa516073fb 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -158,7 +158,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { bobNode.internals.disableDBCloseOnStop() bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { @@ -267,7 +267,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { val issuer = bank.ref(1, 2, 3) bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { fillUpForSeller(false, issuer, alice, diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt index 6f0fa3278c..1930e7ffd8 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt @@ -244,7 +244,6 @@ class FlowSoftLocksTests { 100.DOLLARS, bankNode.services, thisManyStates, - thisManyStates, cashIssuer ) } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 6344331c28..153e51110e 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -22,7 +22,6 @@ import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.DealState import net.corda.finance.contracts.asset.Cash import net.corda.finance.schemas.CashSchemaV1 -import net.corda.finance.schemas.CashSchemaV1.PersistentCashState import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV3 @@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } protected fun consumeCash(amount: Amount) = vaultFiller.consumeCash(amount, CHARLIE) - private fun setUpDb(_database: CordaPersistence, delay: Long = 0) { - _database.transaction { + + private fun setUpDb(database: CordaPersistence, delay: Long = 0) { + database.transaction { // create new states vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER) val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ") @@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } val sorted = results.states.sortedBy { it.ref.toString() } assertThat(results.states).isEqualTo(sorted) - assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } + assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) } } } @@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")) // count fungible assets val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + val countCriteria = VaultCustomQueryCriteria(count) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } // count fungible assets - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // UNCONSUMED states (default) // count fungible assets - val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) + val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) val fungibleStateCountUnconsumed = vaultService.queryBy>(countCriteriaUnconsumed).otherResults.single() as Long assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) @@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // CONSUMED states // count fungible assets - val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) + val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) val fungibleStateCountConsumed = vaultService.queryBy>(countCriteriaConsumed).otherResults.single() as Long assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) @@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val start = TODAY val end = TODAY.plus(30, ChronoUnit.DAYS) val recordedBetweenExpression = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, + TimeInstantType.RECORDED, ColumnPredicate.Between(start, end)) val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression) val results = vaultService.queryBy(criteria) @@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Future val startFuture = TODAY.plus(1, ChronoUnit.DAYS) val recordedBetweenExpressionFuture = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) + TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture) assertThat(vaultService.queryBy(criteriaFuture).states).isEmpty() } @@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { consumeCash(100.DOLLARS) val asOfDateTime = TODAY val consumedAfterExpression = TimeCondition( - QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) + TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED, timeCondition = consumedAfterExpression) val results = vaultService.queryBy(criteria) @@ -1711,6 +1711,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } // pagination: invalid page size + @Suppress("INTEGER_OVERFLOW") @Test(timeout=300_000) fun `invalid page size`() { expectedEx.expect(VaultQueryException::class.java) @@ -1718,8 +1719,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { database.transaction { vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER) - @Suppress("EXPECTED_CONDITION") - val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648 + val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648 val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) vaultService.queryBy(criteria, paging = pagingSpec) } @@ -1839,9 +1839,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { println("$index : $any") } assertThat(results.otherResults.size).isEqualTo(402) - val instants = results.otherResults.filter { it is Instant }.map { it as Instant } + val instants = results.otherResults.filterIsInstance() assertThat(instants).isSorted - val longs = results.otherResults.filter { it is Long }.map { it as Long } + val longs = results.otherResults.filterIsInstance() assertThat(longs.size).isEqualTo(201) assertThat(instants.size).isEqualTo(201) assertThat(longs.sum()).isEqualTo(20100L) @@ -1969,8 +1969,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() { database.transaction { val uid = UniqueIdentifier("999") - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid) - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234") + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid) + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234") val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id)) val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234")) @@ -2119,6 +2119,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } } + @Test(timeout = 300_000) + fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() { + val linearNumbers = Array(2) { LongArray(2) } + // Make sure states from the same transaction are not given consecutive linear numbers. + linearNumbers[0][0] = 1L + linearNumbers[0][1] = 3L + linearNumbers[1][0] = 2L + linearNumbers[1][1] = 4L + + val results = database.transaction { + vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex -> + DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex]) + } + + val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber")) + vaultService.queryBy(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn))) + } + assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L)) + } + @Test(timeout=300_000) fun `return consumed linear states for a given linear id`() { database.transaction { @@ -2448,7 +2468,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { services.recordTransactions(commercialPaper2) val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) val result = vaultService.queryBy(criteria1) @@ -2491,9 +2511,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days) val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L) - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) - val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex) - val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) + val criteria2 = VaultCustomQueryCriteria(maturityIndex) + val criteria3 = VaultCustomQueryCriteria(faceValueIndex) vaultService.queryBy(criteria1.and(criteria3).and(criteria2)) } @@ -2516,8 +2536,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) val results = builder { - val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode) - val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L) + val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode) + val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L) val customCriteria1 = VaultCustomQueryCriteria(currencyIndex) val customCriteria2 = VaultCustomQueryCriteria(quantityIndex) @@ -2768,7 +2788,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Enrich and override QueryCriteria with additional default attributes (such as soft locks) val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich - softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), + softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), status = Vault.StateStatus.UNCONSUMED) // override // Sorting val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) @@ -3114,7 +3134,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3137,7 +3157,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3160,7 +3180,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt b/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt index 1bcf1ac389..175f52f5ae 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt @@ -1,30 +1,50 @@ -@file:Suppress("UNUSED_PARAMETER") @file:JvmName("TestUtils") +@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList") package net.corda.testing.core import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.StateRef -import net.corda.core.crypto.* +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignatureScheme +import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.toX500Name import net.corda.core.internal.unspecifiedCountry import net.corda.core.node.NodeInfo import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.millis +import net.corda.core.utilities.minutes +import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA +import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA -import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames +import org.bouncycastle.asn1.x509.CRLReason +import org.bouncycastle.asn1.x509.DistributionPointName +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.ExtensionsGenerator +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralNames +import org.bouncycastle.asn1.x509.IssuingDistributionPoint +import org.bouncycastle.cert.jcajce.JcaX509CRLConverter +import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils +import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import java.math.BigInteger +import java.net.URI import java.security.KeyPair import java.security.PublicKey +import java.security.cert.X509CRL import java.security.cert.X509Certificate import java.time.Duration import java.time.Instant +import java.util.* import java.util.concurrent.atomic.AtomicInteger import kotlin.test.fail @@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party return getTestPartyAndCertificate(Party(name, publicKey)) } +fun createCRL(issuer: CertificateAndKeyPair, + revokedCerts: List, + issuingDistPoint: URI? = null, + thisUpdate: Instant = Instant.now(), + nextUpdate: Instant = thisUpdate + 5.minutes, + indirect: Boolean = false, + revocationDate: Instant = thisUpdate, + crlReason: Int = CRLReason.keyCompromise, + signatureAlgorithm: String = "SHA256withECDSA"): X509CRL { + val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate)) + val extensionUtils = JcaX509ExtensionUtils() + builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate)) + // This is required and needs to match the certificate settings with respect to being indirect + builder.addExtension( + Extension.issuingDistributionPoint, + true, + IssuingDistributionPoint( + issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) }, + indirect, + false + ) + ) + builder.setNextUpdate(Date.from(nextUpdate)) + for (revokedCert in revokedCerts) { + val extensionsGenerator = ExtensionsGenerator() + extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason)) + // Certificate issuer is required for indirect CRL + extensionsGenerator.addExtension( + Extension.certificateIssuer, + true, + GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name())) + ) + builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate()) + } + val bcProvider = Crypto.findProvider("BC") + val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private) + return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer)) +} private val count = AtomicInteger(0) /** @@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party * The above will test our expectation that the getWaitingFlows action was executed successfully considering * that it may take a few hundreds of milliseconds for the flow state machine states to settle. */ -@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod") fun executeTest( timeout: Duration, cleanup: (() -> Unit)? = null, diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt index 744f33c72c..b6dee805fa 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt @@ -4,30 +4,26 @@ package net.corda.testing.node.internal.network import net.corda.core.crypto.Crypto import net.corda.core.internal.CertRole +import net.corda.core.internal.toX500Name import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.contextLogger import net.corda.core.utilities.days import net.corda.core.utilities.minutes -import net.corda.core.utilities.seconds import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.ContentSignerBuilder import net.corda.nodeapi.internal.crypto.X509Utilities +import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames import net.corda.nodeapi.internal.crypto.certificateType import net.corda.nodeapi.internal.crypto.toJca -import org.bouncycastle.asn1.x500.X500Name +import net.corda.testing.core.createCRL import org.bouncycastle.asn1.x509.CRLDistPoint import org.bouncycastle.asn1.x509.DistributionPoint import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralNames -import org.bouncycastle.asn1.x509.IssuingDistributionPoint -import org.bouncycastle.asn1.x509.ReasonFlags -import org.bouncycastle.cert.jcajce.JcaX509CRLConverter -import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils -import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.ServerConnector import org.eclipse.jetty.server.handler.HandlerCollection @@ -36,11 +32,12 @@ import org.eclipse.jetty.servlet.ServletHolder import org.glassfish.jersey.server.ResourceConfig import org.glassfish.jersey.servlet.ServletContainer import java.io.Closeable -import java.math.BigInteger import java.net.InetSocketAddress +import java.net.URI import java.security.KeyPair import java.security.cert.X509CRL import java.security.cert.X509Certificate +import java.time.Duration import java.util.* import javax.security.auth.x500.X500Principal import javax.ws.rs.GET @@ -51,7 +48,7 @@ import kotlin.collections.ArrayList class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { companion object { - private const val SIGNATURE_ALGORITHM = "SHA256withECDSA" + private val logger = contextLogger() const val NODE_CRL = "node.crl" const val FORBIDDEN_CRL = "forbidden.crl" @@ -72,8 +69,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { null ) if (crlDistPoint != null) { - val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) - val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(X500Name.getInstance(it.encoded))) } + val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier)) + val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) } val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) } @@ -87,14 +84,17 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { } } - val revokedNodeCerts: MutableList = ArrayList() - val revokedIntermediateCerts: MutableList = ArrayList() + val revokedNodeCerts: MutableList = ArrayList() + val revokedIntermediateCerts: MutableList = ArrayList() val rootCa: CertificateAndKeyPair = DEV_ROOT_CA private lateinit var _intermediateCa: CertificateAndKeyPair val intermediateCa: CertificateAndKeyPair get() = _intermediateCa + @Volatile + var delay: Duration? = null + val hostAndPort: NetworkHostAndPort get() = server.connectors.mapNotNull { it as? ServerConnector } .map { NetworkHostAndPort(it.host, it.localPort) } @@ -106,7 +106,7 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"), DEV_INTERMEDIATE_CA.keyPair ) - println("Network management web services started on $hostAndPort") + logger.info("Network management web services started on $hostAndPort") } fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate, @@ -115,29 +115,20 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer) } - fun createRevocationList(signatureAlgorithm: String, - ca: CertificateAndKeyPair, - endpoint: String, - indirect: Boolean, - serialNumbers: List): X509CRL { - println("Generating CRL for $endpoint") - val builder = JcaX509v2CRLBuilder(ca.certificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis())) - val extensionUtils = JcaX509ExtensionUtils() - builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(ca.certificate)) - val issuingDistPointName = GeneralName(GeneralName.uniformResourceIdentifier, "http://$hostAndPort/crl/$endpoint") - // This is required and needs to match the certificate settings with respect to being indirect - val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false) - builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint) - builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis())) - serialNumbers.forEach { - builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold) - } - val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(ca.keyPair.private) - return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer)) + private fun createServerCRL(issuer: CertificateAndKeyPair, + endpoint: String, + indirect: Boolean, + revokedCerts: List): X509CRL { + logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}") + return createCRL( + issuer, + revokedCerts, + issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"), + indirect = indirect + ) } override fun close() { - println("Shutting down network management web services...") server.stop() server.join() } @@ -159,8 +150,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { @Path(NODE_CRL) @Produces("application/pkcs7-crl") fun getNodeCRL(): Response { - return Response.ok(crlServer.createRevocationList( - SIGNATURE_ALGORITHM, + crlServer.delay?.toMillis()?.let(Thread::sleep) + return Response.ok(crlServer.createServerCRL( crlServer.intermediateCa, NODE_CRL, false, @@ -179,8 +170,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { @Path(INTERMEDIATE_CRL) @Produces("application/pkcs7-crl") fun getIntermediateCRL(): Response { - return Response.ok(crlServer.createRevocationList( - SIGNATURE_ALGORITHM, + crlServer.delay?.toMillis()?.let(Thread::sleep) + return Response.ok(crlServer.createServerCRL( crlServer.rootCa, INTERMEDIATE_CRL, false, @@ -192,11 +183,11 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { @Path(EMPTY_CRL) @Produces("application/pkcs7-crl") fun getEmptyCRL(): Response { - return Response.ok(crlServer.createRevocationList( - SIGNATURE_ALGORITHM, + return Response.ok(crlServer.createServerCRL( crlServer.rootCa, EMPTY_CRL, - true, emptyList() + true, + emptyList() ).encoded).build() } } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt index 1dbf7249cb..ab080c3d7e 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt @@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.SchemaMigration +import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.serialization.internal.amqp.AMQP_ENABLED import net.corda.testing.core.ALICE_NAME @@ -52,6 +53,8 @@ import java.io.IOException import java.net.ServerSocket import java.nio.file.Path import java.security.KeyPair +import java.security.cert.X509CRL +import java.security.cert.X509Certificate import java.util.* import java.util.jar.JarOutputStream import java.util.jar.Manifest @@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L return sslConfig } +fun fixedCrlSource(crls: Set): CrlSource { + return object : CrlSource { + override fun fetch(certificate: X509Certificate): Set = crls + } +} + /** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */ @SuppressWarnings("LongParameterList") fun createWireTransaction(inputs: List, diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt index 467b54ea22..f2775e1878 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt @@ -1,6 +1,20 @@ +@file:Suppress("LongParameterList") + package net.corda.testing.internal.vault -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.AttachmentConstraint +import net.corda.core.contracts.AutomaticPlaceholderConstraint +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.CommandAndState +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.Issued +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.PartyAndReference +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.UniqueIdentifier import net.corda.core.crypto.Crypto import net.corda.core.crypto.SignatureMetadata import net.corda.core.identity.AbstractParty @@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Obligation import net.corda.finance.contracts.asset.OnLedgerAsset import net.corda.finance.workflows.asset.CashUtils -import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState -import net.corda.testing.core.DummyCommandData import net.corda.testing.core.TestIdentity import net.corda.testing.core.dummyCommand import net.corda.testing.core.singleIdentity @@ -32,6 +44,7 @@ import java.time.Duration import java.time.Instant import java.time.Instant.now import java.util.* +import kotlin.math.floor /** * The service hub should provide at least a key management service and a storage service. @@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor( private val rngFactory: () -> Random = { Random(0L) }) { companion object { fun calculateRandomlySizedAmounts(howMuch: Amount, min: Int, max: Int, rng: Random): LongArray { - val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt() + val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt() val baseSize = howMuch.quantity / numSlots check(baseSize > 0) { baseSize } @@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor( issuerServices: ServiceHub = services, participants: List = emptyList(), includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val participantsToUse = if (includeMe) participants.plus(me) else participants - - val transactions: List = dealIds.map { - // Issue a deal state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) - } - val stx = issuerServices.signInitialTransaction(dummyIssue) - return@map services.addSignature(stx, defaultNotary.publicKey) + return fillWithTestStates( + txCount = dealIds.size, + participants = participants, + includeMe = includeMe, + services = issuerServices + ) { participantsToUse, txIndex, _ -> + DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearStates(numberToCreate: Int, + fun fillWithSomeTestLinearStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), uniqueIdentifier: UniqueIdentifier? = null, @@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor( linearTimestamp: Instant = now(), constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val participantsToUse = if (includeMe) participants.plus(me) else participants - val transactions: List = (1..numberToCreate).map { - // Issue a Linear state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyLinearContract.State( - linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), - participants = participantsToUse, - linearString = linearString, - linearNumber = linearNumber, - linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID, - constraint = constraint) - addCommand(dummyCommand()) - } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) + return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ -> + DummyLinearContract.State( + linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), + participants = participantsToUse, + linearString = linearString, + linearNumber = linearNumber, + linearBoolean = linearBoolean, + linearTimestamp = linearTimestamp + ) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, + fun fillWithSomeTestLinearAndDealStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), linearString: String = "", linearNumber: Long = 0L, linearBoolean: Boolean = false, - linearTimestamp: Instant = now()): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val transactions: List = (1..numberToCreate).map { - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - // Issue a Linear state - addOutputState(DummyLinearContract.State( + linearTimestamp: Instant = now()): Vault { + return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex -> + when (stateIndex) { + 0 -> DummyLinearContract.State( linearId = UniqueIdentifier(externalId), - participants = participants.plus(me), + participants = participantsToUse, linearString = linearString, linearNumber = linearNumber, linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) - // Issue a Deal state - addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) + linearTimestamp = linearTimestamp + ) + else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse) } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) } - services.recordTransactions(transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - return Vault(states) } - @JvmOverloads - fun fillWithSomeTestCash(howMuch: Amount, - issuerServices: ServiceHub, - thisManyStates: Int, - issuedBy: PartyAndReference, - owner: AbstractParty? = null, - rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord) - /** * Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them * to the vault. This is intended for unit tests. By default the cash is owned by the legal @@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor( * @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). */ + @JvmOverloads fun fillWithSomeTestCash(howMuch: Amount, issuerServices: ServiceHub, atLeastThisManyStates: Int, - atMostThisManyStates: Int, issuedBy: PartyAndReference, owner: AbstractParty? = null, rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT, + atMostThisManyStates: Int = atLeastThisManyStates): Vault { val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory()) // We will allocate one state to one transaction, for simplicities sake. val cash = Cash() @@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor( cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary) return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) } - services.recordTransactions(statesToRecord, transactions) - // Get all the StateRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) + return recordTransactions(transactions, statesToRecord) } /** * Records a dummy state in the Vault (useful for creating random states when testing vault queries) */ - fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())) : Vault { - val outputState = TransactionState( - data = DummyState(Random().nextInt(), participants = participants), - contract = DummyContract.PROGRAM_ID, - notary = defaultNotary.party - ) - val participantKeys : List = participants.map { it.owningKey } - val builder = TransactionBuilder() - .addOutputState(outputState) - .addCommand(DummyCommandData, participantKeys) - val stxn = services.signInitialTransaction(builder) - services.recordTransactions(stxn) - return Vault(setOf(stxn.tx.outRef(0))) + fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())): Vault { + return fillWithTestStates(participants = participants) { participantsToUse, _, _ -> + DummyState(Random().nextInt(), participants = participantsToUse) + } } - /** - * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. - */ - fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount>, owner: AbstractParty, notary: Party) - = OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue()) - + fun fillWithTestStates(txCount: Int = 1, + statesPerTx: Int = 1, + participants: List = emptyList(), + constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, + includeMe: Boolean = true, + services: ServiceHub = this.services, + genOutputState: (participantsToUse: List, txIndex: Int, stateIndex: Int) -> T): Vault { + val issuerKey = defaultNotary.keyPair + val signatureMetadata = SignatureMetadata( + services.myInfo.platformVersion, + Crypto.findSignatureScheme(issuerKey.public).schemeNumberID + ) + val participantsToUse = if (includeMe) { + participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey) + } else { + participants + } + val transactions = Array(txCount) { txIndex -> + val builder = TransactionBuilder(notary = defaultNotary.party) + repeat(statesPerTx) { stateIndex -> + builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint) + } + builder.addCommand(dummyCommand()) + services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata) + } + val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE + return recordTransactions(transactions.asList(), statesToRecord) + } /** * @@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor( val me = AnonymousParty(myKey) val issuance = TransactionBuilder(null as Party?) - generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary) + OnLedgerAsset.generateIssue( + issuance, + TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary), + Obligation.Commands.Issue() + ) val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) - services.recordTransactions(transaction) - return Vault(setOf(transaction.tx.outRef(0))) + return recordTransactions(listOf(transaction)) } - private fun consume(states: List>) { + fun consumeStates(states: Iterable>) { // Create a txn consuming different contract types states.forEach { val builder = TransactionBuilder(notary = altNotary).apply { @@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor( } } - fun consumeDeals(dealStates: List>) = consume(dealStates) - fun consumeLinearStates(linearStates: List>) = consume(linearStates) + fun consumeDeals(dealStates: List>) = consumeStates(dealStates) + fun consumeLinearStates(linearStates: List>) = consumeStates(linearStates) fun evolveLinearStates(linearStates: List>) = consumeAndProduce(linearStates) fun evolveLinearState(linearState: StateAndRef): StateAndRef = consumeAndProduce(linearState) + /** * Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios, * where nodes have a default identity. @@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor( services.recordTransactions(spendTx) return update.getOrThrow(Duration.ofSeconds(3)) } + + private fun recordTransactions(transactions: Iterable, + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + services.recordTransactions(statesToRecord, transactions) + // Get all the StateAndRefs of all the generated transactions. + val states = transactions.flatMap { stx -> + stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } + } + return Vault(states) + } } @@ -344,4 +326,3 @@ data class CommodityState( override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner)) } -