From 0a617097be60a2a2cc83dc6ff8d67b9d3eb35bbd Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Tue, 2 May 2023 14:38:56 +0100 Subject: [PATCH 01/12] ENT-9806: Prevent Netty threads being blocked due to unresponsive CRL endpoints --- .../crypto/internal/ProviderMapTest.kt | 29 + .../internal/crypto/X509UtilitiesTest.kt | 75 ++- .../nodeapi/internal/ArtemisTcpTransport.kt | 114 ++-- .../corda/nodeapi/internal/NodeApiUtils.kt | 32 + .../internal/bridging/AMQPBridgeManager.kt | 20 +- .../nodeapi/internal/crypto/X509Utilities.kt | 38 +- .../protonwrapper/netty/AMQPClient.kt | 48 +- .../protonwrapper/netty/AMQPServer.kt | 67 +- .../netty/AllowAllRevocationChecker.kt | 2 + .../protonwrapper/netty/RevocationConfig.kt | 15 - .../internal/protonwrapper/netty/SSLHelper.kt | 117 ++-- .../revocation/CertDistPointCrlSource.kt | 90 +-- .../internal/crypto/TlsDiffAlgorithmsTest.kt | 9 +- .../internal/crypto/TlsDiffProtocolsTest.kt | 10 +- .../protonwrapper/netty/SSLHelperTest.kt | 24 +- .../revocation/CertDistPointCrlSourceTest.kt | 13 +- .../revocation/CordaRevocationCheckerTest.kt | 8 +- .../internal/revocation/RevocationTest.kt | 113 ++-- node/build.gradle | 2 - .../node/amqp/AMQPClientSslErrorsTest.kt | 27 +- .../CertificateRevocationListNodeTests.kt | 616 +++++++++--------- .../net/corda/node/amqp/ProtonWrapperTests.kt | 27 +- .../net/corda/node/internal/AbstractNode.kt | 19 +- .../internal/artemis/BrokerJaasLoginModule.kt | 8 +- .../artemis/CertificateChainCheckPolicy.kt | 37 -- .../messaging/ArtemisMessagingServer.kt | 35 +- .../messaging/NodeNettyAcceptorFactory.kt | 139 +++- .../services/rpc/RpcBrokerConfiguration.kt | 2 +- .../net/corda/testing/core/TestUtils.kt | 67 +- .../node/internal/network/CrlServer.kt | 75 +-- .../testing/internal/InternalTestUtils.kt | 9 + 31 files changed, 1110 insertions(+), 777 deletions(-) create mode 100644 core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt create mode 100644 node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt rename node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt => node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt (59%) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt new file mode 100644 index 0000000000..11a7e9ed85 --- /dev/null +++ b/core-tests/src/test/kotlin/net/corda/coretests/crypto/internal/ProviderMapTest.kt @@ -0,0 +1,29 @@ +package net.corda.coretests.crypto.internal + +import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.testing.core.createCRL +import org.assertj.core.api.Assertions.assertThatIllegalArgumentException +import org.junit.Test + +class ProviderMapTest { + // https://github.com/corda/corda/pull/3997 + @Test(timeout = 300_000) + fun `verify CRL algorithms`() { + val crl = createCRL( + issuer = DEV_ROOT_CA, + revokedCerts = emptyList(), + signatureAlgorithm = "SHA256withECDSA" + ) + // This should pass. + crl.verify(DEV_ROOT_CA.keyPair.public) + + // Try changing the algorithm to EC will fail. + assertThatIllegalArgumentException().isThrownBy { + createCRL( + issuer = DEV_ROOT_CA, + revokedCerts = emptyList(), + signatureAlgorithm = "EC" + ) + }.withMessage("Unknown signature type requested: EC") + } +} diff --git a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt index 9ce80590a1..9affc6a0b1 100644 --- a/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt +++ b/node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto/X509UtilitiesTest.kt @@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.days import net.corda.core.utilities.hours -import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme +import net.corda.coretesting.internal.NettyTestClient +import net.corda.coretesting.internal.NettyTestHandler +import net.corda.coretesting.internal.NettyTestServer +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.createDevNodeCa +import net.corda.nodeapi.internal.crypto.CertificateType +import net.corda.nodeapi.internal.crypto.X509CertificateFactory +import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME +import net.corda.nodeapi.internal.crypto.checkValidity +import net.corda.nodeapi.internal.crypto.getSupportedKey +import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore +import net.corda.nodeapi.internal.crypto.save +import net.corda.nodeapi.internal.crypto.toBc +import net.corda.nodeapi.internal.crypto.x509 +import net.corda.nodeapi.internal.crypto.x509Certificates import net.corda.nodeapi.internal.installDevNodeCaCertPath -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import net.corda.nodeapi.internal.registerDevP2pCertificates +import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationFactoryImpl @@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.core.TestIdentity import net.corda.testing.driver.internal.incrementalPortAllocation -import net.corda.coretesting.internal.NettyTestClient -import net.corda.coretesting.internal.NettyTestHandler -import net.corda.coretesting.internal.NettyTestServer -import net.corda.testing.internal.createDevIntermediateCaCertPath -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.nodeapi.internal.crypto.CertificateType -import net.corda.nodeapi.internal.crypto.X509CertificateFactory -import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.crypto.checkValidity -import net.corda.nodeapi.internal.crypto.getSupportedKey -import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore -import net.corda.nodeapi.internal.crypto.save -import net.corda.nodeapi.internal.crypto.toBc -import net.corda.nodeapi.internal.crypto.x509 -import net.corda.nodeapi.internal.crypto.x509Certificates import net.corda.testing.internal.IS_OPENJ9 +import net.corda.testing.internal.createDevIntermediateCaCertPath import net.i2p.crypto.eddsa.EdDSAPrivateKey import org.assertj.core.api.Assertions.assertThat -import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x509.BasicConstraints +import org.bouncycastle.asn1.x509.CRLDistPoint +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.KeyUsage +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.junit.Assume @@ -74,10 +80,19 @@ import java.security.PrivateKey import java.security.cert.CertPath import java.security.cert.X509Certificate import java.util.* -import javax.net.ssl.* +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLParameters +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket import javax.security.auth.x500.X500Principal import kotlin.concurrent.thread -import kotlin.test.* +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.test.fail class X509UtilitiesTest { private companion object { @@ -295,15 +310,10 @@ class X509UtilitiesTest { sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - val context = SSLContext.getInstance("TLS") - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get()) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -388,15 +398,8 @@ class X509UtilitiesTest { sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) - - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustManagerFactory.init(trustStore) - + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) + val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get()) val sslServerContext = SslContextBuilder .forServer(keyManagerFactory) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index 6183cfe818..c164a21342 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -1,16 +1,18 @@ +@file:Suppress("LongParameterList") + package net.corda.nodeapi.internal import net.corda.core.messaging.ClientRpcSslOptions import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.BrokerRpcSslOptions -import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.apache.activemq.artemis.api.core.TransportConfiguration import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants -import java.nio.file.Path +import javax.net.ssl.TrustManagerFactory @Suppress("LongParameterList") class ArtemisTcpTransport { @@ -23,6 +25,7 @@ class ArtemisTcpTransport { val TLS_VERSIONS = listOf("TLSv1.2") const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout" + const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory" const val TRACE_NAME = "Corda-Trace" const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName" @@ -30,7 +33,6 @@ class ArtemisTcpTransport { // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // It does not use AMQP messages for its own messages e.g. topology and heartbeats. private const val P2P_PROTOCOLS = "CORE,AMQP" - private const val RPC_PROTOCOLS = "CORE" private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf( @@ -39,46 +41,35 @@ class ArtemisTcpTransport { TransportConstants.PORT_PROP_NAME to hostAndPort.port, TransportConstants.PROTOCOLS_PROP_NAME to protocols, TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null), - TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1), // turn off direct delivery in Artemis - this is latency optimisation that can lead to //hick-ups under high load (CORDA-1336) TransportConstants.DIRECT_DELIVER to false) - private val defaultSSLOptions = mapOf( - TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","), - TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(",")) - private fun SslConfiguration.addToTransportOptions(options: MutableMap) { + if (keyStore != null || trustStore != null) { + options[TransportConstants.SSL_ENABLED_PROP_NAME] = true + options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true + } keyStore?.let { with (it) { path.requireOnDefaultFileSystem() - options.putAll(get().toKeyStoreTransportOptions(path)) + options[TransportConstants.KEYSTORE_PROVIDER_PROP_NAME] = "JKS" + options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path + options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password } } trustStore?.let { with (it) { path.requireOnDefaultFileSystem() - options.putAll(get().toTrustStoreTransportOptions(path)) + options[TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME] = "JKS" + options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path + options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password } } options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT } - private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf( - TransportConstants.SSL_ENABLED_PROP_NAME to true, - TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to "JKS", - TransportConstants.KEYSTORE_PATH_PROP_NAME to path, - TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password, - TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true) - - private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf( - TransportConstants.SSL_ENABLED_PROP_NAME to true, - TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to "JKS", - TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path, - TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password, - TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true) - private fun ClientRpcSslOptions.toTransportOptions() = mapOf( TransportConstants.SSL_ENABLED_PROP_NAME to true, TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to trustStoreProvider, @@ -94,50 +85,64 @@ class ArtemisTcpTransport { fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, + trustManagerFactory: TrustManagerFactory?, enableSSL: Boolean = true, threadPoolName: String = "P2PServer", - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (enableSSL) { config?.addToTransportOptions(options) } - return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) + return createAcceptorTransport( + hostAndPort, + P2P_PROTOCOLS, + options, + trustManagerFactory, + enableSSL, + threadPoolName, + trace, + remotingThreads + ) } fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true, threadPoolName: String = "P2PClient", - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (enableSSL) { config?.addToTransportOptions(options) } - return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) + return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads) } fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: BrokerRpcSslOptions?, enableSSL: Boolean = true, - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (config != null && enableSSL) { config.keyStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) } - return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace) + return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, "RPCServer", trace, remotingThreads) } fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: ClientRpcSslOptions?, enableSSL: Boolean = true, - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() if (config != null && enableSSL) { config.trustStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) } - return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads) } fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, @@ -145,25 +150,45 @@ class ArtemisTcpTransport { trace: Boolean = false): TransportConfiguration { val options = mutableMapOf() config.addToTransportOptions(options) - return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace, null) } fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, - trace: Boolean = false): TransportConfiguration { + trace: Boolean = false, + remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() config.addToTransportOptions(options) - return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace) + return createAcceptorTransport( + hostAndPort, + RPC_PROTOCOLS, + options, + trustManagerFactory(requireNotNull(config.trustStore).get()), + true, + "Internal-RPCServer", + trace, + remotingThreads + ) } private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort, protocols: String, options: MutableMap, + trustManagerFactory: TrustManagerFactory?, enableSSL: Boolean, threadPoolName: String, - trace: Boolean): TransportConfiguration { + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 + if (trustManagerFactory != null) { + // NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use + // more customised instances which use our revocation checkers, which we pass directly into NodeNettyAcceptorFactory. + // + // This, however, requires copying a lot of code from NettyAcceptor into NodeNettyAcceptor. The version of Artemis in + // Corda 4.9 solves this problem by introducing a "trustManagerFactoryPlugin" config option. + options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory + } return createTransport( "net.corda.node.services.messaging.NodeNettyAcceptorFactory", hostAndPort, @@ -171,7 +196,8 @@ class ArtemisTcpTransport { options, enableSSL, threadPoolName, - trace + trace, + remotingThreads ) } @@ -180,7 +206,8 @@ class ArtemisTcpTransport { options: MutableMap, enableSSL: Boolean, threadPoolName: String, - trace: Boolean): TransportConfiguration { + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { return createTransport( "net.corda.node.services.messaging.NodeNettyConnectorFactory", hostAndPort, @@ -188,7 +215,8 @@ class ArtemisTcpTransport { options, enableSSL, threadPoolName, - trace + trace, + remotingThreads ) } @@ -198,11 +226,15 @@ class ArtemisTcpTransport { options: MutableMap, enableSSL: Boolean, threadPoolName: String, - trace: Boolean): TransportConfiguration { + trace: Boolean, + remotingThreads: Int?): TransportConfiguration { options += defaultArtemisOptions(hostAndPort, protocols) if (enableSSL) { - options += defaultSSLOptions + options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",") + options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",") } + // By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357) + options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1 options[THREAD_POOL_NAME_NAME] = threadPoolName options[TRACE_NAME] = trace return TransportConfiguration(className, options) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt new file mode 100644 index 0000000000..65d60ab38d --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeApiUtils.kt @@ -0,0 +1,32 @@ +@file:Suppress("LongParameterList", "MagicNumber") + +package net.corda.nodeapi.internal + +import io.netty.util.concurrent.DefaultThreadFactory +import net.corda.core.utilities.seconds +import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.TimeUnit + +/** + * Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0 + * threads. + */ +fun namedThreadPoolExecutor(maxPoolSize: Int, + corePoolSize: Int = 0, + idleKeepAlive: Duration = 30.seconds, + workQueue: BlockingQueue = LinkedBlockingQueue(), + poolName: String = "pool", + daemonThreads: Boolean = false, + threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor { + return ThreadPoolExecutor( + corePoolSize, + maxPoolSize, + idleKeepAlive.toNanos(), + TimeUnit.NANOSECONDS, + workQueue, + DefaultThreadFactory(poolName, daemonThreads, threadPriority) + ) +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt index ee09e640ad..deb6ef999a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt @@ -100,7 +100,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private class AMQPBridge(val sourceX500Name: String, val queueName: String, val targets: List, - val legalNames: Set, + val allowedRemoteLegalNames: Set, private val amqpConfig: AMQPConfiguration, sharedEventGroup: EventLoopGroup, private val artemis: ArtemisSessionProvider, @@ -116,7 +116,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, MDC.put("queueName", queueName) MDC.put("source", amqpConfig.sourceX500Name) MDC.put("targets", targets.joinToString(separator = ";") { it.toString() }) - MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() }) + MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() }) MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString()) block() } finally { @@ -134,7 +134,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } - val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup) + val amqpClient = AMQPClient(targets, allowedRemoteLegalNames, amqpConfig, sharedThreadPool = sharedEventGroup) private var session: ClientSession? = null private var consumer: ClientConsumer? = null private var connectedSubscription: Subscription? = null @@ -231,7 +231,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } ArtemisState.STOPPING } - bridgeMetricsService?.bridgeDisconnected(targets, legalNames) + bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames) connectedSubscription?.unsubscribe() connectedSubscription = null // Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first. @@ -243,7 +243,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, if (connected) { logInfoWithMDC("Bridge Connected") - bridgeMetricsService?.bridgeConnected(targets, legalNames) + bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames) if (bridgeConnectionTTLSeconds > 0) { // AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS, @@ -286,7 +286,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, logInfoWithMDC("Bridge Disconnected") amqpRestartEvent?.cancel(false) if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) { - bridgeMetricsService?.bridgeDisconnected(targets, legalNames) + bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames) } artemis(ArtemisState.STOPPING) { precedingState: ArtemisState -> logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected") @@ -418,10 +418,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore, properties[key] = value } } - logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } + logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } val peerInbox = translateLocalQueueToInboxAddress(queueName) val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox, - legalNames.first().toString(), + allowedRemoteLegalNames.first().toString(), properties) sendableMessage.onComplete.then { logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" } @@ -486,7 +486,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, queueNamesToBridgesMap.remove(queueName) } bridge.stop() - bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames) + bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames) } } } @@ -498,7 +498,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, val bridges = queueNamesToBridgesMap[queueName]?.toList() destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) bridges?.map { - it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false) + it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false) }?.toMap() ?: emptyMap() } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt index 9bc577e831..79ae834a16 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt @@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto import net.corda.core.CordaOID import net.corda.core.crypto.Crypto import net.corda.core.crypto.newSecureRandom -import net.corda.core.internal.* +import net.corda.core.internal.CertRole +import net.corda.core.internal.SignedDataWithCert +import net.corda.core.internal.reader +import net.corda.core.internal.signWithCert +import net.corda.core.internal.uncheckedCast +import net.corda.core.internal.validate +import net.corda.core.internal.writer import net.corda.core.utilities.days import net.corda.core.utilities.millis import net.corda.core.utilities.toHex import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString -import org.bouncycastle.asn1.* +import org.bouncycastle.asn1.ASN1EncodableVector +import org.bouncycastle.asn1.ASN1ObjectIdentifier +import org.bouncycastle.asn1.ASN1Sequence +import org.bouncycastle.asn1.DERSequence +import org.bouncycastle.asn1.DERUTF8String import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.style.BCStyle -import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier +import org.bouncycastle.asn1.x509.BasicConstraints +import org.bouncycastle.asn1.x509.CRLDistPoint +import org.bouncycastle.asn1.x509.DistributionPoint +import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralNames +import org.bouncycastle.asn1.x509.KeyPurposeId +import org.bouncycastle.asn1.x509.KeyUsage +import org.bouncycastle.asn1.x509.NameConstraints +import org.bouncycastle.asn1.x509.SubjectKeyIdentifier +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509v3CertificateBuilder import org.bouncycastle.cert.bc.BcX509ExtensionUtils @@ -32,8 +53,13 @@ import java.nio.file.Path import java.security.KeyPair import java.security.PublicKey import java.security.SignatureException -import java.security.cert.* +import java.security.cert.CertPath import java.security.cert.Certificate +import java.security.cert.CertificateException +import java.security.cert.CertificateFactory +import java.security.cert.TrustAnchor +import java.security.cert.X509CRL +import java.security.cert.X509Certificate import java.time.Duration import java.time.Instant import java.time.temporal.ChronoUnit @@ -359,7 +385,7 @@ object X509Utilities { private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) { if (crlDistPoint != null) { - val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) + val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier)) val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(crlIssuer)) } @@ -379,6 +405,8 @@ object X509Utilities { bytes[0] = bytes[0].and(0x3F).or(0x40) return BigInteger(bytes) } + + fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string)) } // Assuming cert type to role is 1:1 diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 3e8fc485ba..3c18830147 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -27,15 +27,14 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME import net.corda.nodeapi.internal.requireMessageSize +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import rx.Observable import rx.subjects.PublishSubject import java.lang.Long.min import java.net.InetSocketAddress -import java.security.cert.CertPathValidatorException +import java.util.concurrent.ExecutorService import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock enum class ProxyVersion { @@ -63,7 +62,8 @@ class AMQPClient(private val targets: List, val allowedRemoteLegalNames: Set, private val configuration: AMQPConfiguration, private val sharedThreadPool: EventLoopGroup? = null, - private val threadPoolName: String = "AMQPClient") : AutoCloseable { + private val threadPoolName: String = "AMQPClient", + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) @@ -89,12 +89,12 @@ class AMQPClient(private val targets: List, private var targetIndex = 0 private var currentTarget: NetworkHostAndPort = targets.first() private var retryInterval = MIN_RETRY_INTERVAL - private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker() private val badCertTargets = mutableSetOf() @Volatile private var amqpActive = false @Volatile private var amqpChannelHandler: ChannelHandler? = null + private var sslDelegatedTaskExecutor: ExecutorService? = null val localAddressString: String get() = clientChannel?.localAddress()?.toString() ?: "" @@ -150,17 +150,16 @@ class AMQPClient(private val targets: List, } private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer() { - private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore) + private val trustManagerFactory = trustManagerFactoryWithRevocation( + parent.configuration.trustStore, + parent.configuration.revocationConfig, + parent.distPointCrlSource + ) private val conf = parent.configuration @Volatile private lateinit var amqpChannelHandler: AMQPChannelHandler - init { - keyManagerFactory.init(conf.keyStore) - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker)) - } - @Suppress("ComplexMethod") override fun initChannel(ch: SocketChannel) { val pipeline = ch.pipeline() @@ -199,10 +198,24 @@ class AMQPClient(private val targets: List, val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val target = parent.currentTarget + val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor) val handler = if (parent.configuration.useOpenSsl) { - createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) + createClientOpenSslHandler( + target, + parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, + trustManagerFactory, + ch.alloc(), + delegatedTaskExecutor + ) } else { - createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) + createClientSslHandler( + target, + parent.allowedRemoteLegalNames, + wrappedKeyManagerFactory, + trustManagerFactory, + delegatedTaskExecutor + ) } handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis() pipeline.addLast("sslHandler", handler) @@ -260,6 +273,7 @@ class AMQPClient(private val targets: List, return } log.info("Connect to: $currentTarget") + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) started = true restart() @@ -294,6 +308,8 @@ class AMQPClient(private val targets: List, } clientChannel = null workerGroup = null + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null log.info("Stopped connection to $currentTarget") } } @@ -334,6 +350,4 @@ class AMQPClient(private val targets: List, private val _onConnection = PublishSubject.create().toSerialized() val onConnection: Observable get() = _onConnection - - val softFailExceptions: List get() = revocationChecker.softFailExceptions -} \ No newline at end of file +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt index cbeb2562b4..523cde184a 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPServer.kt @@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.requireMessageSize +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import org.apache.qpid.proton.engine.Delivery import rx.Observable import rx.subjects.PublishSubject import java.net.BindException import java.net.InetSocketAddress -import java.security.cert.CertPathValidatorException import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutorService import java.util.concurrent.locks.ReentrantLock -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.TrustManagerFactory import kotlin.concurrent.withLock /** @@ -39,37 +38,34 @@ import kotlin.concurrent.withLock class AMQPServer(val hostName: String, val port: Int, private val configuration: AMQPConfiguration, - private val threadPoolName: String = "AMQPServer") : AutoCloseable { + private val threadPoolName: String = "AMQPServer", + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, + private val remotingThreads: Int? = null) : AutoCloseable { companion object { init { InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) } - private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads" - private val log = contextLogger() - private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4) + private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4) } private val lock = ReentrantLock() - @Volatile - private var stopping: Boolean = false private var bossGroup: EventLoopGroup? = null private var workerGroup: EventLoopGroup? = null private var serverChannel: Channel? = null - private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker() + private var sslDelegatedTaskExecutor: ExecutorService? = null private val clientChannels = ConcurrentHashMap() private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer() { - private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore) + private val trustManagerFactory = trustManagerFactoryWithRevocation( + parent.configuration.trustStore, + parent.configuration.revocationConfig, + parent.distPointCrlSource + ) private val conf = parent.configuration - init { - keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray()) - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker)) - } - override fun initChannel(ch: SocketChannel) { val amqpConfiguration = parent.configuration val pipeline = ch.pipeline() @@ -116,11 +112,12 @@ class AMQPServer(val hostName: String, Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) } else { val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig) + val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor) val handler = if (amqpConfig.useOpenSsl) { - createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc()) + createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor) } else { // For javaSSL, SNI matching is handled at key manager level. - createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory) + createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor) } handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis() Pair(handler, mapOf(DEFAULT to keyManagerFactory)) @@ -132,8 +129,13 @@ class AMQPServer(val hostName: String, lock.withLock { stop() + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY)) - workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)) + workerGroup = NioEventLoopGroup( + remotingThreads ?: DEFAULT_REMOTING_THREADS, + DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY) + ) val server = ServerBootstrap() // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux @@ -154,22 +156,19 @@ class AMQPServer(val hostName: String, fun stop() { lock.withLock { - try { - stopping = true - serverChannel?.apply { close() } - serverChannel = null + serverChannel?.close() + serverChannel = null - workerGroup?.shutdownGracefully() - workerGroup?.terminationFuture()?.sync() + workerGroup?.shutdownGracefully() + workerGroup?.terminationFuture()?.sync() + workerGroup = null - bossGroup?.shutdownGracefully() - bossGroup?.terminationFuture()?.sync() + bossGroup?.shutdownGracefully() + bossGroup?.terminationFuture()?.sync() + bossGroup = null - workerGroup = null - bossGroup = null - } finally { - stopping = false - } + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null } } @@ -226,6 +225,4 @@ class AMQPServer(val hostName: String, private val _onConnection = PublishSubject.create().toSerialized() val onConnection: Observable get() = _onConnection - - val softFailExceptions: List get() = revocationChecker.softFailExceptions } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt index 538bb17a1e..a853cbffc8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AllowAllRevocationChecker.kt @@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() { override fun getSoftFailExceptions(): List { return Collections.emptyList() } + + override fun clone(): AllowAllRevocationChecker = this } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt index 6e8c695387..4e1b4b1930 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/RevocationConfig.kt @@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty import com.typesafe.config.Config import net.corda.nodeapi.internal.config.ConfigParser import net.corda.nodeapi.internal.config.CustomConfigParser -import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource -import net.corda.nodeapi.internal.revocation.CordaRevocationChecker -import java.security.cert.PKIXRevocationChecker /** * Data structure for controlling the way how Certificate Revocation Lists are handled. @@ -45,18 +42,6 @@ interface RevocationConfig { * Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE` */ val externalCrlSource: CrlSource? - - fun createPKIXRevocationChecker(): PKIXRevocationChecker { - return when (mode) { - Mode.OFF -> AllowAllRevocationChecker - Mode.EXTERNAL_SOURCE -> { - val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" } - CordaRevocationChecker(externalCrlSource, softFail = true) - } - Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true) - Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false) - } - } } /** diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt index 705fbc2905..7de3e5e302 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt @@ -1,3 +1,5 @@ +@file:Suppress("ComplexMethod", "LongParameterList") + package net.corda.nodeapi.internal.protonwrapper.netty import io.netty.buffer.ByteBufAllocator @@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.x509 +import net.corda.nodeapi.internal.namedThreadPoolExecutor +import net.corda.nodeapi.internal.revocation.CordaRevocationChecker import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.ASN1Primitive import org.bouncycastle.asn1.DERIA5String @@ -34,10 +38,10 @@ import java.net.URI import java.security.KeyStore import java.security.cert.CertificateException import java.security.cert.PKIXBuilderParameters -import java.security.cert.PKIXRevocationChecker import java.security.cert.X509CertSelector import java.security.cert.X509Certificate import java.util.concurrent.Executor +import java.util.concurrent.ThreadPoolExecutor import javax.net.ssl.CertPathTrustManagerParameters import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SNIHostName @@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine import javax.net.ssl.TrustManagerFactory import javax.net.ssl.X509ExtendedTrustManager import javax.security.auth.x500.X500Principal -import kotlin.system.measureTimeMillis private const val HOSTNAME_FORMAT = "%s.corda.net" internal const val DEFAULT = "default" @@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton /** * Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any. */ -@Suppress("ComplexMethod") fun X509Certificate.distributionPoints(): Map?> { logger.debug { "Checking CRLDPs for $subjectX500Principal" } @@ -117,6 +119,14 @@ fun certPathToString(certPath: Array?): String { return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" } } +/** + * Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3, + * which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor. + */ +fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor { + return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask") +} + @VisibleForTesting class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() { companion object { @@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex } -private object LoggingImmediateExecutor : Executor { - - override fun execute(command: Runnable) { - val log = LoggerFactory.getLogger(javaClass) - - @Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions - try { - val commandName = command::class.qualifiedName?.let { "[$it]" } ?: "" - log.debug("Entering SSL command $commandName") - val elapsedTime = measureTimeMillis { command.run() } - log.debug("Exiting SSL command $elapsedTime millis") - if (elapsedTime > 100) { - log.info("Command: $commandName took $elapsedTime millis to execute") - } - } - catch (ex: Exception) { - log.error("Caught exception in SSL handler executor", ex) - throw ex - } - } -} - internal fun createClientSslHandler(target: NetworkHostAndPort, expectedRemoteLegalNames: Set, keyManagerFactory: KeyManagerFactory, - trustManagerFactory: TrustManagerFactory): SslHandler { + trustManagerFactory: TrustManagerFactory, + delegateTaskExecutor: Executor): SslHandler { val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslEngine = sslContext.createSSLEngine(target.host, target.port) sslEngine.useClientMode = true @@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort, sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslEngine.sslParameters = sslParameters } - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createClientOpenSslHandler(target: NetworkHostAndPort, expectedRemoteLegalNames: Set, keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory, - alloc: ByteBufAllocator): SslHandler { + alloc: ByteBufAllocator, + delegateTaskExecutor: Executor): SslHandler { val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() val sslEngine = sslContext.newEngine(alloc, target.host, target.port) sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() @@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort, sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslEngine.sslParameters = sslParameters } - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createServerSslHandler(keyStore: CertificateStore, keyManagerFactory: KeyManagerFactory, - trustManagerFactory: TrustManagerFactory): SslHandler { + trustManagerFactory: TrustManagerFactory, + delegateTaskExecutor: Executor): SslHandler { val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslEngine = sslContext.createSSLEngine() sslEngine.useClientMode = false @@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore, val sslParameters = sslEngine.sslParameters sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore)) sslEngine.sslParameters = sslParameters - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory, - alloc: ByteBufAllocator): SslHandler { + alloc: ByteBufAllocator, + delegateTaskExecutor: Executor): SslHandler { val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build() val sslEngine = sslContext.newEngine(alloc) sslEngine.useClientMode = false - return SslHandler(sslEngine, false, LoggingImmediateExecutor) + return SslHandler(sslEngine, false, delegateTaskExecutor) } -fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext { +fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext { val sslContext = SSLContext.getInstance("TLS") - val keyManagers = keyManagerFactory.keyManagers - val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java) - .map { LoggingTrustManagerWrapper(it) }.toTypedArray() - sslContext.init(keyManagers, trustManagers, newSecureRandom()) + val trustManagers = trustManagerFactory + ?.trustManagers + ?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it } + ?.toTypedArray() + sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom()) return sslContext } -fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, - revocationConfig: RevocationConfig): CertPathTrustManagerParameters { - return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker()) -} - -fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, - revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters { - val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector()) - pkixParams.addCertPathChecker(revocationChecker) - return CertPathTrustManagerParameters(pkixParams) -} - /** * Creates a special SNI handler used only when openSSL is used for AMQPServer */ @@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map { @@ -325,9 +305,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map AllowAllRevocationChecker + RevocationConfig.Mode.EXTERNAL_SOURCE -> { + val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) { + "externalCrlSource must be specfied for EXTERNAL_SOURCE" + } + CordaRevocationChecker(externalCrlSource, softFail = true) + } + RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true) + RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false) + } + val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector()) + pkixParams.addCertPathChecker(revocationChecker) + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams)) + return trustManagerFactory +} /** * Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt index db984f11b8..ee589e73a9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSource.kt @@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache import net.corda.core.internal.readFully import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource @@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints import java.net.URI import java.security.cert.X509CRL import java.security.cert.X509Certificate -import java.util.concurrent.TimeUnit +import java.time.Duration import javax.security.auth.x500.X500Principal /** - * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate. + * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them. */ @Suppress("TooGenericExceptionCaught") -class CertDistPointCrlSource : CrlSource { +class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE, + cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY, + private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT, + private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource { companion object { private val logger = contextLogger() // The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a // node handshake, we want to keep the total timeout within that. - private const val DEFAULT_CONNECT_TIMEOUT = 9_000 - private const val DEFAULT_READ_TIMEOUT = 9_000 + private val DEFAULT_CONNECT_TIMEOUT = 9.seconds + private val DEFAULT_READ_TIMEOUT = 9.seconds private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore) - private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L + private val DEFAULT_CACHE_EXPIRY = 5.minutes - private val cache: LoadingCache = Caffeine.newBuilder() - .maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE)) - .expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS) - .build(::retrieveCRL) + val SINGLETON = CertDistPointCrlSource( + cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE), + cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY, + connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT, + readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT + ) + } - private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT) - private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT) + private val cache: LoadingCache = Caffeine.newBuilder() + .maximumSize(cacheSize) + .expireAfterWrite(cacheExpiry) + .build(::retrieveCRL) - private fun retrieveCRL(uri: URI): X509CRL { - val start = System.currentTimeMillis() - val bytes = try { - val conn = uri.toURL().openConnection() - conn.connectTimeout = connectTimeout - conn.readTimeout = readTimeout - // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes - // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException - // into CRLException and drops the cause chain. - conn.getInputStream().readFully() - } catch (e: Exception) { - if (logger.isDebugEnabled) { - logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) - } - throw e + private fun retrieveCRL(uri: URI): X509CRL { + val start = System.currentTimeMillis() + val bytes = try { + val conn = uri.toURL().openConnection() + conn.connectTimeout = connectTimeout.toMillis().toInt() + conn.readTimeout = readTimeout.toMillis().toInt() + // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes + // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException + // into CRLException and drops the cause chain. + conn.getInputStream().readFully() + } catch (e: Exception) { + if (logger.isDebugEnabled) { + logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) } - val duration = System.currentTimeMillis() - start - val crl = try { - X509CertificateFactory().generateCRL(bytes.inputStream()) - } catch (e: Exception) { - if (logger.isDebugEnabled) { - logger.debug("Invalid CRL from $uri (${duration}ms)", e) - } - throw e - } - logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" } - return crl + throw e } + val duration = System.currentTimeMillis() - start + val crl = try { + X509CertificateFactory().generateCRL(bytes.inputStream()) + } catch (e: Exception) { + if (logger.isDebugEnabled) { + logger.debug("Invalid CRL from $uri (${duration}ms)", e) + } + throw e + } + logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" } + return crl + } + + fun clearCache() { + cache.invalidateAll() } override fun fetch(certificate: X509Certificate): Set { diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt index 21c7bc8a94..e951c587c2 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffAlgorithmsTest.kt @@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.assertj.core.api.Assertions import org.junit.Rule import org.junit.Test @@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { return SSLContext.getInstance("TLS").apply { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(keyStore) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(trustStore) val trustManagers = trustMgrFactory.trustManagers init(keyManagers, trustManagers, newSecureRandom()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt index 46b6bf381c..0bb81e5627 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/crypto/TlsDiffProtocolsTest.kt @@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom import net.corda.core.utilities.Try import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import org.assertj.core.api.Assertions import org.junit.Ignore import org.junit.Rule @@ -18,7 +19,6 @@ import java.io.IOException import java.net.InetAddress import java.net.InetSocketAddress import javax.net.ssl.* -import javax.net.ssl.SNIHostName import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { return SSLContext.getInstance("TLS").apply { - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(keyStore) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(trustStore) val trustManagers = trustMgrFactory.trustManagers init(keyManagers, trustManagers, newSecureRandom()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt index 8fbe5fd9d9..12eb6d3e35 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelperTest.kt @@ -1,5 +1,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty +import io.netty.util.concurrent.ImmediateExecutor import net.corda.core.crypto.SecureHash import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.NetworkHostAndPort @@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS +import net.corda.testing.internal.fixedCrlSource import org.junit.Test -import javax.net.ssl.KeyManagerFactory import javax.net.ssl.SNIHostName -import javax.net.ssl.TrustManagerFactory import kotlin.test.assertEquals class SSLHelperTest { @@ -20,15 +20,21 @@ class SSLHelperTest { val legalName = CordaX500Name("Test", "London", "GB") val sslConfig = configureTestSSL(legalName) - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) - val keyStore = sslConfig.keyStore - keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false)) - val trustStore = sslConfig.trustStore - trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL))) + val trustManagerFactory = trustManagerFactoryWithRevocation( + sslConfig.trustStore.get(), + RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL), + fixedCrlSource(emptySet()) + ) - val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory) + val sslHandler = createClientSslHandler( + NetworkHostAndPort("localhost", 1234), + setOf(legalName), + keyManagerFactory, + trustManagerFactory, + ImmediateExecutor.INSTANCE + ) val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase() // These hardcoded values must not be changed, something is broken if you have to change these hardcoded values. diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt index 489c66de32..66c17e4a39 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CertDistPointCrlSourceTest.kt @@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation import net.corda.core.crypto.Crypto import net.corda.core.utilities.NetworkHostAndPort -import net.corda.nodeapi.internal.createDevNodeCa -import net.corda.testing.core.ALICE_NAME +import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA import net.corda.testing.node.internal.network.CrlServer import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.jce.provider.BouncyCastleProvider import org.junit.After import org.junit.Before import org.junit.Test -import java.math.BigInteger class CertDistPointCrlSourceTest { private lateinit var crlServer: CrlServer @@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest { assertThat(single().revokedCertificates).isNull() } - val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate) + crlSource.clearCache() - crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN) - with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache + crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate + with(crlSource.fetch(crlServer.intermediateCa.certificate)) { assertThat(size).isEqualTo(1) val revokedCertificates = single().revokedCertificates - assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN) + // This also tests clearCache() works. + assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber) } } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt index 6ccf8ce594..6dbfcd4515 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/CordaRevocationCheckerTest.kt @@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource +import net.corda.testing.internal.fixedCrlSource import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory import org.junit.Test import java.math.BigInteger @@ -41,10 +41,8 @@ class CordaRevocationCheckerTest { val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl") val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL - val crlSource = object : CrlSource { - override fun fetch(certificate: X509Certificate): Set = setOf(crl) - } - val checker = CordaRevocationChecker(crlSource, + val checker = CordaRevocationChecker( + crlSource = fixedCrlSource(setOf(crl)), softFail = true, dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) } ) diff --git a/node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt similarity index 59% rename from node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt index 2e984eb3b5..675f222f92 100644 --- a/node/src/test/kotlin/net/corda/node/internal/artemis/RevocationCheckTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/revocation/RevocationTest.kt @@ -1,20 +1,16 @@ -package net.corda.node.internal.artemis +package net.corda.nodeapi.internal.revocation import net.corda.core.crypto.Crypto -import net.corda.core.utilities.days -import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType +import net.corda.nodeapi.internal.crypto.X509KeyStore import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation +import net.corda.testing.core.createCRL import org.bouncycastle.asn1.x500.X500Name -import org.bouncycastle.asn1.x509.CRLReason -import org.bouncycastle.asn1.x509.Extension -import org.bouncycastle.asn1.x509.ExtensionsGenerator -import org.bouncycastle.asn1.x509.GeneralName -import org.bouncycastle.asn1.x509.GeneralNames -import org.bouncycastle.asn1.x509.IssuingDistributionPoint -import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import org.junit.Before import org.junit.Rule import org.junit.Test @@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith import org.junit.runners.Parameterized import java.io.File +import java.security.KeyPair import java.security.KeyStore import java.security.PrivateKey +import java.security.cert.CertificateException import java.security.cert.X509Certificate import java.util.* +import javax.net.ssl.X509TrustManager import javax.security.auth.x500.X500Principal -import kotlin.test.assertFails +import kotlin.test.assertFailsWith @RunWith(Parameterized::class) -class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { +class RevocationTest(private val revocationMode: RevocationConfig.Mode) { companion object { @JvmStatic @Parameterized.Parameters(name = "revocationMode = {0}") @@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { private lateinit var doormanCRL: File private lateinit var tlsCRL: File - private val keyStore = KeyStore.getInstance("JKS") - private val trustStore = KeyStore.getInstance("JKS") + private lateinit var trustManager: X509TrustManager private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) @@ -61,9 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { private lateinit var tlsCert: X509Certificate private val chain - get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).map { - javax.security.cert.X509Certificate.getInstance(it.encoded) - }.toTypedArray() + get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert) @Before fun before() { @@ -74,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair) tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair) + val trustStore = KeyStore.getInstance("JKS") trustStore.load(null, null) trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert) trustStore.setCertificateEntry("cordarootca", rootCert) + val trustManagerFactory = trustManagerFactoryWithRevocation( + CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"), + RevocationConfigImpl(revocationMode), + CertDistPointCrlSource() + ) + trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager + doormanCert = X509Utilities.createCertificate( CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public, crlDistPoint = rootCRL.toURI().toString() @@ -91,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - rootCRL.createCRL(rootCert, rootKeyPair.private, false) - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false) - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) + rootCRL.writeCRL(rootCert, rootKeyPair.private, false) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) } - private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { - val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date()) - builder.setNextUpdate(Date.from(Date().toInstant() + 7.days)) - builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false)) - revoked.forEach { - val extensionsGenerator = ExtensionsGenerator() - extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise)) - // Certificate issuer is required for indirect CRL - val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded) - extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName))) - builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate()) - } - val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey)) - outputStream().use { it.write(holder.encoded) } + private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { + val crl = createCRL( + CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)), + revoked.asList(), + indirect = indirect + ) + writeBytes(crl.encoded) } - private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) { - if (revocationMode in modes) assertFails(block) else block() + private fun assertFailsFor(vararg modes: RevocationConfig.Mode) { + if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck() } @Test(timeout = 300_000) fun `ok with empty CRLs`() { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() } @Test(timeout = 300_000) fun `soft fail with revoked TLS certificate`() { - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) - assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -138,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -150,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown") ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -162,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = tlsCRL.toURI().toString() ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -174,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public, crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) ) - tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) + tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() } @Test(timeout = 300_000) fun `soft fail with revoked node CA certificate`() { - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) - assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -195,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman" ) - assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) - } + assertFailsFor(RevocationConfig.Mode.HARD_FAIL) } @Test(timeout = 300_000) @@ -207,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public, crlDistPoint = doormanCRL.toURI().toString() ) - doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert) + doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert) - RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) + doRevocationCheck() + } + + private fun doRevocationCheck() { + trustManager.checkClientTrusted(chain, "ECDHE_ECDSA") } } diff --git a/node/build.gradle b/node/build.gradle index 493d4f83cc..fe895177f7 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -264,8 +264,6 @@ tasks.register('integrationTest', Test) { testClassesDirs = sourceSets.integrationTest.output.classesDirs classpath = sourceSets.integrationTest.runtimeClasspath maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger() - // CertificateRevocationListNodeTests - systemProperty 'net.corda.dpcrl.connect.timeout', '4000' } tasks.register('slowIntegrationTest', Test) { diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt index a5b49f4fe1..050142097e 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt @@ -14,12 +14,15 @@ import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration -import net.corda.nodeapi.internal.protonwrapper.netty.init -import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME import net.corda.testing.driver.internal.incrementalPortAllocation +import net.corda.testing.internal.fixedCrlSource import org.junit.Assume.assumeFalse import org.junit.Before import org.junit.Rule @@ -96,11 +99,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { override val maxMessageSize: Int = MAX_MESSAGE_SIZE } - serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + serverKeyManagerFactory = keyManagerFactory(keyStore) - serverKeyManagerFactory.init(keyStore) - serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig)) + serverTrustManagerFactory = trustManagerFactoryWithRevocation( + serverAmqpConfig.trustStore, + RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL), + fixedCrlSource(emptySet()) + ) } private fun setupClientCertificates() { @@ -127,11 +132,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) { override val sslHandshakeTimeout: Duration = 3.seconds } - clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + clientKeyManagerFactory = keyManagerFactory(keyStore) - clientKeyManagerFactory.init(keyStore) - clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig)) + clientTrustManagerFactory = trustManagerFactoryWithRevocation( + clientAmqpConfig.trustStore, + RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL), + fixedCrlSource(emptySet()) + ) } @Test(timeout = 300_000) diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt index 4bfdfb2a93..06e12a966f 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt @@ -1,3 +1,5 @@ +@file:Suppress("LongParameterList") + package net.corda.node.amqp import com.nhaarman.mockito_kotlin.doReturn @@ -5,10 +7,10 @@ import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.Crypto import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div -import net.corda.core.internal.rootCause import net.corda.core.internal.times -import net.corda.core.toFuture import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.NodeConfiguration @@ -18,63 +20,67 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.config.CertificateStoreSupplier import net.corda.nodeapi.internal.config.MutualSslConfiguration -import net.corda.nodeapi.internal.crypto.X509Utilities +import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA +import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer +import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.node.internal.network.CrlServer import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL -import net.corda.testing.node.internal.network.CrlServer.Companion.FORBIDDEN_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint import org.apache.activemq.artemis.api.core.RoutingType import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatIllegalArgumentException import org.bouncycastle.jce.provider.BouncyCastleProvider import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder -import java.net.SocketTimeoutException +import java.io.Closeable import java.security.cert.X509Certificate import java.time.Duration +import java.util.concurrent.BlockingQueue +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger -import kotlin.test.assertEquals +import java.util.stream.IntStream -@Suppress("LongParameterList") -class CertificateRevocationListNodeTests { +abstract class AbstractServerRevocationTest { @Rule @JvmField val temporaryFolder = TemporaryFolder() private val portAllocation = incrementalPortAllocation() - private val serverPort = portAllocation.nextPort() + protected val serverPort = portAllocation.nextPort() - private lateinit var crlServer: CrlServer - private lateinit var amqpServer: AMQPServer - private lateinit var amqpClient: AMQPClient + protected lateinit var crlServer: CrlServer + private val amqpClients = ArrayList() - private abstract class AbstractNodeConfiguration : NodeConfiguration + protected lateinit var defaultCrlDistPoints: CrlDistPoints + + protected abstract class AbstractNodeConfiguration : NodeConfiguration companion object { private val unreachableIpCounter = AtomicInteger(1) - private val crlConnectTimeout = Duration.ofMillis(System.getProperty("net.corda.dpcrl.connect.timeout").toLong()) + val crlConnectTimeout = 2.seconds /** * Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes * may not work as the OS process may cache the timeout result. */ - private fun newUnreachableIpAddress(): String { + private fun newUnreachableIpAddress(): NetworkHostAndPort { check(unreachableIpCounter.get() != 255) - return "10.255.255.${unreachableIpCounter.getAndIncrement()}" + return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement()) } } @@ -84,252 +90,190 @@ class CertificateRevocationListNodeTests { Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME) crlServer = CrlServer(NetworkHostAndPort("localhost", 0)) crlServer.start() + defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort) } @After fun tearDown() { - if (::amqpClient.isInitialized) { - amqpClient.close() - } - if (::amqpServer.isInitialized) { - amqpServer.close() - } + amqpClients.parallelStream().forEach(AMQPClient::close) if (::crlServer.isInitialized) { crlServer.close() } } @Test(timeout=300_000) - fun `AMQP server connection works and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - expectedConnectStatus = true + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection works and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection succeeds when soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, - expectedConnectStatus = true + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection fails when client's certificate is revoked and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection fails when client's certificate is revoked and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, revokeClientCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection fails when client's certificate is revoked and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when client's certificate is revoked and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, revokeClientCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection fails when servers's certificate is revoked and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection fails when server's certificate is revoked and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, revokeServerCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection fails when servers's certificate is revoked and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when server's certificate is revoked and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, revokeServerCert = true, - expectedConnectStatus = false + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL cannot be obtained and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", - expectedConnectStatus = true + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"), + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection fails when CRL cannot be obtained and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when CRL cannot be obtained and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", - expectedConnectStatus = false + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"), + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL is not defined and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = null, - expectedConnectStatus = true + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null), + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection fails when CRL is not defined and soft fail is disabled`() { - verifyAMQPConnection( + fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() { + verifyConnection( crlCheckSoftFail = false, - nodeCrlDistPoint = null, - expectedConnectStatus = false + clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null), + expectedConnectedStatus = false ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() { - verifyAMQPConnection( + fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL", - expectedConnectStatus = true + clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null), + expectedConnectedStatus = true ) } @Test(timeout=300_000) - fun `AMQP server connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { - verifyAMQPConnection( - crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", - sslHandshakeTimeout = crlConnectTimeout * 3, - expectedConnectStatus = true + fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() { + verifyConnection( + crlCheckSoftFail = false, + clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null), + expectedConnectedStatus = false ) - val timeoutExceptions = (amqpServer.softFailExceptions + amqpClient.softFailExceptions) - .map { it.rootCause } - .filterIsInstance() - assertThat(timeoutExceptions).isNotEmpty } @Test(timeout=300_000) - fun `AMQP server connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { - verifyAMQPConnection( + fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { + verifyConnection( + crlCheckSoftFail = true, + sslHandshakeTimeout = crlConnectTimeout * 4, + clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()), + expectedConnectedStatus = true + ) + } + + @Test(timeout=300_000) + fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { + verifyConnection( crlCheckSoftFail = true, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", sslHandshakeTimeout = crlConnectTimeout / 2, - expectedConnectStatus = false + clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()), + expectedConnectedStatus = false ) } - @Test(timeout=300_000) - fun `verify CRL algorithms`() { - val crl = crlServer.createRevocationList( - "SHA256withECDSA", - crlServer.rootCa, - EMPTY_CRL, - true, - emptyList() + @Test(timeout = 300_000) + fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() { + val serverCrlSource = CertDistPointCrlSource() + // Start the server and verify the first client has connected + val firstClientConnectionChangeStatus = verifyConnection( + crlCheckSoftFail = true, + crlSource = serverCrlSource, + // In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with + // existing clients. The trick is to make sure at least N new clients are also supported. + remotingThreads = 2, + expectedConnectedStatus = true ) - // This should pass. - crl.verify(crlServer.rootCa.keyPair.public) - // Try changing the algorithm to EC will fail. - assertThatIllegalArgumentException().isThrownBy { - crlServer.createRevocationList( - "EC", - crlServer.rootCa, - EMPTY_CRL, - true, - emptyList() + // Now simulate the CRL endpoint becoming very slow/unreachable + crlServer.delay = 10.minutes + // And pretend enough time has elapsed that the cached CRLs have expired and need downloading again + serverCrlSource.clearCache() + + // Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty + // threads to be tied up in trying to download the CRLs. + IntStream.range(0, 2).parallel().forEach { clientIndex -> + val (newClient, _) = createAMQPClient( + serverPort, + crlCheckSoftFail = true, + legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"), + crlDistPoints = defaultCrlDistPoints ) - }.withMessage("Unknown signature type requested: EC") + newClient.start() + } + + // Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga + assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull() } - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged - ) - } + protected abstract fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout), + sslHandshakeTimeout: Duration? = null, + remotingThreads: Int? = null, + clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints, + revokeClientCert: Boolean = false, + revokeServerCert: Boolean = false, + expectedConnectedStatus: Boolean): BlockingQueue - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with hard fail CRL check`() { - verifyArtemisConnection( - crlCheckSoftFail = false, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL" - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with soft fail CRL check on unreachable URL if CRL timeout is within SSL handshake timeout`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Acknowledged, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", - sslHandshakeTimeout = crlConnectTimeout * 3 - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with soft fail CRL check on unreachable URL if CRL timeout is not within SSL handshake timeout`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedConnected = false, - nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", - sslHandshakeTimeout = crlConnectTimeout / 2 - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() { - verifyArtemisConnection( - crlCheckSoftFail = false, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Rejected, - nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL" - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() { - verifyArtemisConnection( - crlCheckSoftFail = true, - crlCheckArtemisServer = true, - expectedStatus = MessageStatus.Rejected, - revokedNodeCert = true - ) - } - - @Test(timeout = 300_000) - fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() { - verifyArtemisConnection( - crlCheckSoftFail = false, - crlCheckArtemisServer = false, - expectedStatus = MessageStatus.Acknowledged, - revokedNodeCert = true - ) - } - - private fun createAMQPClient(targetPort: Int, - crlCheckSoftFail: Boolean, - legalName: CordaX500Name, - nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL", - maxMessageSize: Int = MAX_MESSAGE_SIZE): X509Certificate { + protected fun createAMQPClient(targetPort: Int, + crlCheckSoftFail: Boolean, + legalName: CordaX500Name, + crlDistPoints: CrlDistPoints): Pair { val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) @@ -343,31 +287,128 @@ class CertificateRevocationListNodeTests { doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail } clientConfig.configureWithDevSSLCertificate() - val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) + val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val keyStore = clientConfig.p2pSslOptions.keyStore.get() val amqpConfig = object : AMQPConfiguration { override val keyStore = keyStore override val trustStore = clientConfig.p2pSslOptions.trustStore.get() - override val maxMessageSize: Int = maxMessageSize + override val maxMessageSize: Int = MAX_MESSAGE_SIZE + override val trace: Boolean = true } - amqpClient = AMQPClient( + val amqpClient = AMQPClient( listOf(NetworkHostAndPort("localhost", targetPort)), setOf(CHARLIE_NAME), amqpConfig, - threadPoolName = legalName.organisation + threadPoolName = legalName.organisation, + distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout) ) + amqpClients += amqpClient + return Pair(amqpClient, nodeCert) + } - return nodeCert + protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue { + val connectionChangeStatus = LinkedBlockingQueue() + onConnection.subscribe { connectionChangeStatus.add(it) } + start() + assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus) + return connectionChangeStatus + } + + protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort, + val nodeCa: String? = NODE_CRL, + val tls: String? = EMPTY_CRL) { + private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" } + private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" } + + fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, + p2pSslConfiguration: MutualSslConfiguration, + crlServer: CrlServer): X509Certificate { + val nodeKeyStore = signingCertificateStore.get() + val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) } + val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint) + val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) + + nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2) + + nodeKeyStore.update { + internal.deleteEntry(CORDA_CLIENT_CA) + } + nodeKeyStore.update { + setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword) + } + + val sslKeyStore = p2pSslConfiguration.keyStore.get() + val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) } + val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal) + val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) + + sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3) + + sslKeyStore.update { + internal.deleteEntry(CORDA_CLIENT_TLS) + } + sslKeyStore.update { + setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword) + } + return newNodeCert + } + } +} + + +class AMQPServerRevocationTest : AbstractServerRevocationTest() { + private lateinit var amqpServer: AMQPServer + + @After + fun shutDown() { + if (::amqpServer.isInitialized) { + amqpServer.close() + } + } + + override fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?, + clientCrlDistPoints: CrlDistPoints, + revokeClientCert: Boolean, + revokeServerCert: Boolean, + expectedConnectedStatus: Boolean): BlockingQueue { + val serverCert = createAMQPServer( + serverPort, + CHARLIE_NAME, + crlCheckSoftFail, + defaultCrlDistPoints, + crlSource, + sslHandshakeTimeout, + remotingThreads + ) + if (revokeServerCert) { + crlServer.revokedNodeCerts.add(serverCert) + } + amqpServer.start() + amqpServer.onReceive.subscribe { + it.complete(true) + } + val (client, clientCert) = createAMQPClient( + serverPort, + crlCheckSoftFail = crlCheckSoftFail, + legalName = ALICE_NAME, + crlDistPoints = clientCrlDistPoints + ) + if (revokeClientCert) { + crlServer.revokedNodeCerts.add(clientCert) + } + + return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) } private fun createAMQPServer(port: Int, legalName: CordaX500Name, crlCheckSoftFail: Boolean, - nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL", - maxMessageSize: Int = MAX_MESSAGE_SIZE, - sslHandshakeTimeout: Duration? = null): X509Certificate { + crlDistPoints: CrlDistPoints, + distPointCrlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?): X509Certificate { check(!::amqpServer.isInitialized) val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" @@ -381,92 +422,101 @@ class CertificateRevocationListNodeTests { doReturn(signingCertificateStore).whenever(it).signingCertificateStore } serverConfig.configureWithDevSSLCertificate() - val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) + val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val keyStore = serverConfig.p2pSslOptions.keyStore.get() val amqpConfig = object : AMQPConfiguration { override val keyStore = keyStore override val trustStore = serverConfig.p2pSslOptions.trustStore.get() override val revocationConfig = crlCheckSoftFail.toRevocationConfig() - override val maxMessageSize: Int = maxMessageSize + override val maxMessageSize: Int = MAX_MESSAGE_SIZE override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout } - amqpServer = AMQPServer("0.0.0.0", port, amqpConfig, threadPoolName = legalName.organisation) - return nodeCert - } - - private fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, - p2pSslConfiguration: MutualSslConfiguration, - nodeCaCrlDistPoint: String?, - tlsCrlDistPoint: String?): X509Certificate { - val nodeKeyStore = signingCertificateStore.get() - val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) } - val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCrlDistPoint) - val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) + - nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2) - - nodeKeyStore.update { - internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA) - } - nodeKeyStore.update { - setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword) - } - - val sslKeyStore = p2pSslConfiguration.keyStore.get() - val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) } - val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal) - val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) + - sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3) - - sslKeyStore.update { - internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS) - } - sslKeyStore.update { - setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword) - } - return newNodeCert - } - - private fun verifyAMQPConnection(crlCheckSoftFail: Boolean, - nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - revokeServerCert: Boolean = false, - revokeClientCert: Boolean = false, - sslHandshakeTimeout: Duration? = null, - expectedConnectStatus: Boolean) { - val serverCert = createAMQPServer( - serverPort, - CHARLIE_NAME, - crlCheckSoftFail = crlCheckSoftFail, - nodeCrlDistPoint = nodeCrlDistPoint, - sslHandshakeTimeout = sslHandshakeTimeout + amqpServer = AMQPServer( + "0.0.0.0", + port, + amqpConfig, + threadPoolName = legalName.organisation, + distPointCrlSource = distPointCrlSource, + remotingThreads = remotingThreads ) - if (revokeServerCert) { - crlServer.revokedNodeCerts.add(serverCert.serialNumber) + return serverCert + } +} + + +class ArtemisServerRevocationTest : AbstractServerRevocationTest() { + private lateinit var artemisNode: ArtemisNode + private var crlCheckArtemisServer = true + + @After + fun shutDown() { + if (::artemisNode.isInitialized) { + artemisNode.close() } - amqpServer.start() - amqpServer.onReceive.subscribe { - it.complete(true) - } - val clientCert = createAMQPClient( + } + + @Test(timeout = 300_000) + fun `connection succeeds with disabled CRL check on revoked node certificate`() { + crlCheckArtemisServer = false + verifyConnection( + crlCheckSoftFail = false, + revokeClientCert = true, + expectedConnectedStatus = true + ) + } + + override fun verifyConnection(crlCheckSoftFail: Boolean, + crlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?, + clientCrlDistPoints: CrlDistPoints, + revokeClientCert: Boolean, + revokeServerCert: Boolean, + expectedConnectedStatus: Boolean): BlockingQueue { + val (client, clientCert) = createAMQPClient( serverPort, - crlCheckSoftFail = crlCheckSoftFail, + crlCheckSoftFail = true, legalName = ALICE_NAME, - nodeCrlDistPoint = nodeCrlDistPoint + crlDistPoints = clientCrlDistPoints ) if (revokeClientCert) { - crlServer.revokedNodeCerts.add(clientCert.serialNumber) + crlServer.revokedNodeCerts.add(clientCert) } - val serverConnected = amqpServer.onConnection.toFuture() - amqpClient.start() - val serverConnect = serverConnected.get() - assertThat(serverConnect.connected).isEqualTo(expectedConnectStatus) + + val nodeCert = startArtemisNode( + CHARLIE_NAME, + crlCheckSoftFail, + defaultCrlDistPoints, + crlSource, + sslHandshakeTimeout, + remotingThreads + ) + if (revokeServerCert) { + crlServer.revokedNodeCerts.add(nodeCert) + } + + val queueName = "${P2P_PREFIX}Test" + artemisNode.client.started!!.session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) + + val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) + + if (expectedConnectedStatus) { + val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap()) + client.write(msg) + assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged) + } + + return clientConnectionChangeStatus } - private fun createArtemisServerAndClient(legalName: CordaX500Name, - crlCheckSoftFail: Boolean, - crlCheckArtemisServer: Boolean, - nodeCrlDistPoint: String, - sslHandshakeTimeout: Duration?): Pair { - val baseDirectory = temporaryFolder.root.toPath() / "artemis" + private fun startArtemisNode(legalName: CordaX500Name, + crlCheckSoftFail: Boolean, + crlDistPoints: CrlDistPoints, + distPointCrlSource: CertDistPointCrlSource, + sslHandshakeTimeout: Duration?, + remotingThreads: Int?): X509Certificate { + check(!::artemisNode.isInitialized) + val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val certificatesDirectory = baseDirectory / "certificates" val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout) @@ -482,60 +532,34 @@ class CertificateRevocationListNodeTests { doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer } artemisConfig.configureWithDevSSLCertificate() - recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, null) + val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer) val server = ArtemisMessagingServer( artemisConfig, artemisConfig.p2pAddress, MAX_MESSAGE_SIZE, threadPoolName = "${legalName.organisation}-server", - trace = true + trace = true, + distPointCrlSource = distPointCrlSource, + remotingThreads = remotingThreads ) val client = ArtemisMessagingClient( artemisConfig.p2pSslOptions, artemisConfig.p2pAddress, MAX_MESSAGE_SIZE, - threadPoolName = "${legalName.organisation}-client", - trace = true + threadPoolName = "${legalName.organisation}-client" ) server.start() client.start() - return server to client + val artemisNode = ArtemisNode(server, client) + this.artemisNode = artemisNode + return nodeCert } - private fun verifyArtemisConnection(crlCheckSoftFail: Boolean, - crlCheckArtemisServer: Boolean, - expectedConnected: Boolean = true, - expectedStatus: MessageStatus? = null, - revokedNodeCert: Boolean = false, - nodeCrlDistPoint: String = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", - sslHandshakeTimeout: Duration? = null) { - val queueName = P2P_PREFIX + "Test" - val (artemisServer, artemisClient) = createArtemisServerAndClient( - CHARLIE_NAME, - crlCheckSoftFail, - crlCheckArtemisServer, - nodeCrlDistPoint, - sslHandshakeTimeout - ) - artemisServer.use { - artemisClient.started!!.session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) - - val nodeCert = createAMQPClient(serverPort, true, ALICE_NAME, nodeCrlDistPoint) - if (revokedNodeCert) { - crlServer.revokedNodeCerts.add(nodeCert.serialNumber) - } - val clientConnected = amqpClient.onConnection.toFuture() - amqpClient.start() - val clientConnect = clientConnected.get() - assertThat(clientConnect.connected).isEqualTo(expectedConnected) - - if (expectedConnected) { - val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap()) - amqpClient.write(msg) - assertEquals(expectedStatus, msg.onComplete.get()) - } - artemisClient.stop() + private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable { + override fun close() { + client.stop() + server.close() } } } diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index c69538aae5..89672379d7 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.whenever import io.netty.channel.EventLoopGroup import io.netty.channel.nio.NioEventLoopGroup +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div import net.corda.core.toFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger +import net.corda.coretesting.internal.rigorousMock +import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.messaging.ArtemisMessagingServer @@ -23,7 +26,9 @@ import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer -import net.corda.nodeapi.internal.protonwrapper.netty.init +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.BOB_NAME @@ -31,9 +36,6 @@ import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.internal.createDevIntermediateCaCertPath -import net.corda.coretesting.internal.rigorousMock -import net.corda.coretesting.internal.stubs.CertificateStoreStubs -import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import org.apache.activemq.artemis.api.core.RoutingType import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Assert.assertArrayEquals @@ -42,7 +44,11 @@ import org.junit.Test import org.junit.rules.TemporaryFolder import java.security.cert.X509Certificate import java.util.concurrent.TimeUnit -import javax.net.ssl.* +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLHandshakeException +import javax.net.ssl.SSLParameters +import javax.net.ssl.SSLServerSocket +import javax.net.ssl.SSLSocket import kotlin.concurrent.thread import kotlin.test.assertEquals import kotlin.test.assertTrue @@ -144,15 +150,10 @@ class ProtonWrapperTests { sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) } sslConfig.createTrustStore(rootCa.certificate) - val keyStore = sslConfig.keyStore.get() - val trustStore = sslConfig.trustStore.get() - val context = SSLContext.getInstance("TLS") - val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - keyManagerFactory.init(keyStore) + val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get()) val keyManagers = keyManagerFactory.keyManagers - val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustMgrFactory.init(trustStore) + val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get()) val trustManagers = trustMgrFactory.trustManagers context.init(keyManagers, trustManagers, newSecureRandom()) @@ -339,7 +340,7 @@ class ProtonWrapperTests { amqpServer.use { val connectionEvents = amqpServer.onConnection.toBlocking().iterator amqpServer.start() - val sharedThreads = NioEventLoopGroup() + val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads")) val amqpClient1 = createSharedThreadsClient(sharedThreads, 0) val amqpClient2 = createSharedThreadsClient(sharedThreads, 1) amqpClient1.start() diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 32a7c15ef2..7be9b94dc9 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -4,7 +4,6 @@ import co.paralleluniverse.fibers.instrument.Retransform import com.codahale.metrics.MetricRegistry import com.google.common.collect.MutableClassToInstanceMap import com.google.common.util.concurrent.MoreExecutors -import com.google.common.util.concurrent.ThreadFactoryBuilder import com.zaxxer.hikari.pool.HikariPool import net.corda.common.logging.errorReporting.NodeDatabaseErrors import net.corda.confidential.SwapIdentitiesFlow @@ -67,6 +66,7 @@ import net.corda.core.toFuture import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.days +import net.corda.core.utilities.millis import net.corda.core.utilities.minutes import net.corda.djvm.source.ApiSource import net.corda.djvm.source.EmptyApi @@ -166,6 +166,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess +import net.corda.nodeapi.internal.namedThreadPoolExecutor import net.corda.tools.shell.InteractiveShell import org.apache.activemq.artemis.utils.ReusableLatch import org.jolokia.jvmagent.JolokiaServer @@ -181,9 +182,6 @@ import java.time.format.DateTimeParseException import java.util.Properties import java.util.concurrent.ExecutorService import java.util.concurrent.Executors -import java.util.concurrent.LinkedBlockingQueue -import java.util.concurrent.ThreadPoolExecutor -import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.SECONDS import java.util.function.Consumer @@ -881,13 +879,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } // Start with 1 thread and scale up to the configured thread pool size if needed // Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool] - return ThreadPoolExecutor( - 1, - numberOfThreads, - 0L, - TimeUnit.MILLISECONDS, - LinkedBlockingQueue(), - ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build() + return namedThreadPoolExecutor( + corePoolSize = 1, + maxPoolSize = numberOfThreads, + idleKeepAlive = 0.millis, + poolName = "flow-external-operation-thread", + daemonThreads = true ) } diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt index c146629364..9658fe2e53 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt @@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() { Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE))) } ArtemisMessagingComponent.PEER_USER -> { - requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } + val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } requireTls(certificates) // This check is redundant as it was performed already during the SSL handshake - CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates!!) - CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode) - .createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) + CertificateChainCheckPolicy.RootMustMatch + .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore) + .checkCertificateChain(certificates!!) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) } else -> { diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt index db14a63965..de1ac38bc8 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/CertificateChainCheckPolicy.kt @@ -2,17 +2,9 @@ package net.corda.node.internal.artemis import net.corda.core.identity.CordaX500Name import net.corda.core.utilities.contextLogger -import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig -import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl -import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString import java.security.KeyStore -import java.security.cert.CertPathValidator -import java.security.cert.CertPathValidatorException import java.security.cert.CertificateException -import java.security.cert.PKIXBuilderParameters -import java.security.cert.X509CertSelector sealed class CertificateChainCheckPolicy { companion object { @@ -92,33 +84,4 @@ sealed class CertificateChainCheckPolicy { } } } - - class RevocationCheck(val revocationConfig: RevocationConfig) : CertificateChainCheckPolicy() { - constructor(revocationMode: RevocationConfig.Mode) : this(RevocationConfigImpl(revocationMode)) - - override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check { - return object : Check { - override fun checkCertificateChain(theirChain: Array) { - // Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate. - val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) } - log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}") - - // Drop the last certificate which must be a trusted root (validated by RootMustMatch). - // Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain. - // See PKIXValidator.engineValidate() for reference implementation. - val certPath = X509Utilities.buildCertPath(chain.dropLast(1)) - val certPathValidator = CertPathValidator.getInstance("PKIX") - val pkixRevocationChecker = revocationConfig.createPKIXRevocationChecker() - val params = PKIXBuilderParameters(trustStore, X509CertSelector()) - params.addCertPathChecker(pkixRevocationChecker) - try { - certPathValidator.validate(certPath, params) - } catch (ex: CertPathValidatorException) { - log.error("Bad certificate path", ex) - throw ex - } - } - } - } - } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index d1e0a7ea16..aef1c9820c 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug -import net.corda.node.internal.artemis.* +import net.corda.node.internal.artemis.ArtemisBroker +import net.corda.node.internal.artemis.BrokerAddresses +import net.corda.node.internal.artemis.BrokerJaasLoginModule import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE +import net.corda.node.internal.artemis.NodeJaasConfig +import net.corda.node.internal.artemis.P2PJaasConfig +import net.corda.node.internal.artemis.SecureArtemisConfiguration +import net.corda.node.internal.artemis.UserValidationPlugin +import net.corda.node.internal.artemis.isBindingError import net.corda.node.services.config.NodeConfiguration import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor import net.corda.nodeapi.internal.ArtemisMessageSizeChecksInterceptor @@ -20,7 +27,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl +import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation import net.corda.nodeapi.internal.requireOnDefaultFileSystem +import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl @@ -33,7 +43,6 @@ import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import java.io.IOException import java.lang.Long.max -import java.security.KeyStoreException import javax.annotation.concurrent.ThreadSafe import javax.security.auth.login.AppConfigurationEntry import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED @@ -57,7 +66,9 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, private val maxMessageSize: Int, private val journalBufferTimeout : Int? = null, private val threadPoolName: String = "ArtemisServer", - private val trace: Boolean = false) : ArtemisBroker, SingletonSerializeAsToken() { + private val trace: Boolean = false, + private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, + private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() { companion object { private val log = contextLogger() } @@ -91,9 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, override val started: Boolean get() = activeMQServer.isStarted - // TODO: Maybe wrap [IOException] on a key store load error so that it's clearly splitting key store loading from - // Artemis IO errors - @Throws(IOException::class, AddressBindingException::class, KeyStoreException::class) + @Suppress("ThrowsCount") private fun configureAndStartServer() { val artemisConfig = createArtemisConfig() val securityManager = createArtemisSecurityManager() @@ -133,11 +142,23 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, // The transaction cache is configurable, and drives other cache sizes. globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize) + val revocationMode = if (config.crlCheckArtemisServer) { + if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL + } else { + RevocationConfig.Mode.OFF + } + val trustManagerFactory = trustManagerFactoryWithRevocation( + config.p2pSslOptions.trustStore.get(), + RevocationConfigImpl(revocationMode), + distPointCrlSource + ) addAcceptorConfiguration(p2pAcceptorTcpTransport( NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port), config.p2pSslOptions, + trustManagerFactory, threadPoolName = threadPoolName, - trace = trace + trace = trace, + remotingThreads = remotingThreads )) // Enable built in message deduplication. Note we still have to do our own as the delayed commits // and our own definition of commit mean that the built in deduplication cannot remove all duplicates. diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt index d8ba1ddbb2..a37fc6abbf 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt @@ -5,13 +5,24 @@ import io.netty.channel.ChannelHandlerContext import io.netty.channel.group.ChannelGroup import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler +import io.netty.handler.ssl.SslContext +import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslHandler import io.netty.handler.ssl.SslHandshakeTimeoutException +import io.netty.handler.ssl.SslProvider import net.corda.core.internal.declaredField import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.ArtemisTcpTransport +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor +import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor +import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants +import org.apache.activemq.artemis.core.remoting.impl.ssl.SSLSupport +import org.apache.activemq.artemis.core.server.ActiveMQServerLogger import org.apache.activemq.artemis.core.server.cluster.ClusterConnection import org.apache.activemq.artemis.spi.core.protocol.ProtocolManager import org.apache.activemq.artemis.spi.core.remoting.Acceptor @@ -21,13 +32,19 @@ import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleLi import org.apache.activemq.artemis.utils.ConfigurationHelper import org.apache.activemq.artemis.utils.actors.OrderedExecutor import java.nio.channels.ClosedChannelException +import java.nio.file.Paths +import java.security.PrivilegedExceptionAction import java.time.Duration import java.util.concurrent.Executor import java.util.concurrent.ScheduledExecutorService import java.util.regex.Pattern +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext import javax.net.ssl.SSLEngine +import javax.net.ssl.TrustManagerFactory +import javax.security.auth.Subject -@Suppress("unused") // Used via reflection in ArtemisTcpTransport +@Suppress("unused", "TooGenericExceptionCaught", "ComplexMethod", "MagicNumber", "TooManyFunctions") class NodeNettyAcceptorFactory : AcceptorFactory { override fun createAcceptor(name: String?, clusterConnection: ClusterConnection?, @@ -57,6 +74,7 @@ class NodeNettyAcceptorFactory : AcceptorFactory { } private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) + private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) @Synchronized @@ -70,11 +88,17 @@ class NodeNettyAcceptorFactory : AcceptorFactory { } } + @Synchronized + override fun stop() { + super.stop() + sslDelegatedTaskExecutor.shutdown() + } + @Synchronized override fun getSslHandler(alloc: ByteBufAllocator?): SslHandler { applyThreadPoolName() - val engine = super.getSslHandler(alloc).engine() - val sslHandler = NodeAcceptorSslHandler(engine, trace) + val engine = getSSLEngine(alloc) + val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace) val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? if (handshakeTimeout != null) { sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis() @@ -91,10 +115,117 @@ class NodeNettyAcceptorFactory : AcceptorFactory { Thread.currentThread().name = "$threadPoolName-${matcher.group(1)}" // Preserve the pool thread number } } + + /** + * This is a copy of [NettyAcceptor.getSslHandler] so that we can provide different implementations for [loadOpenSslEngine] and + * [loadJdkSslEngine]. [NodeNettyAcceptor], instead of creating a default [TrustManagerFactory], will simply use the provided one in + * the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] configuration. + */ + private fun getSSLEngine(alloc: ByteBufAllocator?): SSLEngine { + val engine = if (sslProvider == TransportConstants.OPENSSL_PROVIDER) { + loadOpenSslEngine(alloc) + } else { + loadJdkSslEngine() + } + engine.useClientMode = false + if (needClientAuth) { + engine.needClientAuth = true + } + + // setting the enabled cipher suites resets the enabled protocols so we need + // to save the enabled protocols so that after the customer cipher suite is enabled + // we can reset the enabled protocols if a customer protocol isn't specified + val originalProtocols = engine.enabledProtocols + if (enabledCipherSuites != null) { + try { + engine.enabledCipherSuites = SSLSupport.parseCommaSeparatedListIntoArray(enabledCipherSuites) + } catch (e: IllegalArgumentException) { + ActiveMQServerLogger.LOGGER.invalidCipherSuite(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedCipherSuites)) + throw e + } + } + if (enabledProtocols != null) { + try { + engine.enabledProtocols = SSLSupport.parseCommaSeparatedListIntoArray(enabledProtocols) + } catch (e: IllegalArgumentException) { + ActiveMQServerLogger.LOGGER.invalidProtocol(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedProtocols)) + throw e + } + } else { + engine.enabledProtocols = originalProtocols + } + return engine + } + + /** + * Copy of [NettyAcceptor.loadOpenSslEngine] which invokes our custom [createOpenSslContext]. + */ + private fun loadOpenSslEngine(alloc: ByteBufAllocator?): SSLEngine { + val context = try { + // We copied all this code just so we could replace the SSLSupport.createNettyContext method call with our own one. + createOpenSslContext() + } catch (e: Exception) { + throw IllegalStateException("Unable to create NodeNettyAcceptor", e) + } + return Subject.doAs(null, PrivilegedExceptionAction { + context.newEngine(alloc) + }) + } + + /** + * Copy of [NettyAcceptor.loadJdkSslEngine] which invokes our custom [createJdkSSLContext]. + */ + private fun loadJdkSslEngine(): SSLEngine { + val context = try { + // We copied all this code just so we could replace the SSLHelper.createContext method call with our own one. + createJdkSSLContext() + } catch (e: Exception) { + throw IllegalStateException("Unable to create NodeNettyAcceptor", e) + } + return Subject.doAs(null, PrivilegedExceptionAction { + context.createSSLEngine() + }) + } + + /** + * Create an [SSLContext] using the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config. + */ + private fun createJdkSSLContext(): SSLContext { + return createAndInitSslContext( + createKeyManagerFactory(), + configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? + ) + } + + /** + * Create an [SslContext] using the the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config. + */ + private fun createOpenSslContext(): SslContext { + return SslContextBuilder + .forServer(createKeyManagerFactory()) + .sslProvider(SslProvider.OPENSSL) + .trustManager(configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?) + .build() + } + + private fun createKeyManagerFactory(): KeyManagerFactory { + return keyManagerFactory(CertificateStore.fromFile(Paths.get(keyStorePath), keyStorePassword, keyStorePassword, false)) + } + + // Replicate the fields which are private in NettyAcceptor + private val sslProvider = ConfigurationHelper.getStringProperty(TransportConstants.SSL_PROVIDER, TransportConstants.DEFAULT_SSL_PROVIDER, configuration) + private val needClientAuth = ConfigurationHelper.getBooleanProperty(TransportConstants.NEED_CLIENT_AUTH_PROP_NAME, TransportConstants.DEFAULT_NEED_CLIENT_AUTH, configuration) + private val enabledCipherSuites = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME, TransportConstants.DEFAULT_ENABLED_CIPHER_SUITES, configuration) + private val enabledProtocols = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_PROTOCOLS_PROP_NAME, TransportConstants.DEFAULT_ENABLED_PROTOCOLS, configuration) + private val keyStorePath = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PATH_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PATH, configuration) + private val keyStoreProvider = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PROVIDER, configuration) + private val keyStorePassword = ConfigurationHelper.getPasswordProperty(TransportConstants.KEYSTORE_PASSWORD_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PASSWORD, configuration, ActiveMQDefaultConfiguration.getPropMaskPassword(), ActiveMQDefaultConfiguration.getPropPasswordCodec()) } - private class NodeAcceptorSslHandler(engine: SSLEngine, private val trace: Boolean) : SslHandler(engine) { + private class NodeAcceptorSslHandler(engine: SSLEngine, + delegatedTaskExecutor: Executor, + private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) { companion object { private val logger = contextLogger() } diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt index d8c320cab4..e79f485c01 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt @@ -30,7 +30,7 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int, setDirectories(baseDirectory) val acceptorConfigurationsSet = mutableSetOf( - rpcAcceptorTcpTransport(address, sslOptions, useSsl) + rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl) ) adminAddress?.let { acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration) diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt b/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt index 1bcf1ac389..175f52f5ae 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/testing/core/TestUtils.kt @@ -1,30 +1,50 @@ -@file:Suppress("UNUSED_PARAMETER") @file:JvmName("TestUtils") +@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList") package net.corda.testing.core import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.StateRef -import net.corda.core.crypto.* +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignatureScheme +import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.toX500Name import net.corda.core.internal.unspecifiedCountry import net.corda.core.node.NodeInfo import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.millis +import net.corda.core.utilities.minutes +import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA +import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.X509Utilities -import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA -import net.corda.coretesting.internal.DEV_ROOT_CA +import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames +import org.bouncycastle.asn1.x509.CRLReason +import org.bouncycastle.asn1.x509.DistributionPointName +import org.bouncycastle.asn1.x509.Extension +import org.bouncycastle.asn1.x509.ExtensionsGenerator +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralNames +import org.bouncycastle.asn1.x509.IssuingDistributionPoint +import org.bouncycastle.cert.jcajce.JcaX509CRLConverter +import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils +import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import java.math.BigInteger +import java.net.URI import java.security.KeyPair import java.security.PublicKey +import java.security.cert.X509CRL import java.security.cert.X509Certificate import java.time.Duration import java.time.Instant +import java.util.* import java.util.concurrent.atomic.AtomicInteger import kotlin.test.fail @@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party return getTestPartyAndCertificate(Party(name, publicKey)) } +fun createCRL(issuer: CertificateAndKeyPair, + revokedCerts: List, + issuingDistPoint: URI? = null, + thisUpdate: Instant = Instant.now(), + nextUpdate: Instant = thisUpdate + 5.minutes, + indirect: Boolean = false, + revocationDate: Instant = thisUpdate, + crlReason: Int = CRLReason.keyCompromise, + signatureAlgorithm: String = "SHA256withECDSA"): X509CRL { + val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate)) + val extensionUtils = JcaX509ExtensionUtils() + builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate)) + // This is required and needs to match the certificate settings with respect to being indirect + builder.addExtension( + Extension.issuingDistributionPoint, + true, + IssuingDistributionPoint( + issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) }, + indirect, + false + ) + ) + builder.setNextUpdate(Date.from(nextUpdate)) + for (revokedCert in revokedCerts) { + val extensionsGenerator = ExtensionsGenerator() + extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason)) + // Certificate issuer is required for indirect CRL + extensionsGenerator.addExtension( + Extension.certificateIssuer, + true, + GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name())) + ) + builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate()) + } + val bcProvider = Crypto.findProvider("BC") + val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private) + return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer)) +} private val count = AtomicInteger(0) /** @@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party * The above will test our expectation that the getWaitingFlows action was executed successfully considering * that it may take a few hundreds of milliseconds for the flow state machine states to settle. */ -@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod") fun executeTest( timeout: Duration, cleanup: (() -> Unit)? = null, diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt index 744f33c72c..b6dee805fa 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/CrlServer.kt @@ -4,30 +4,26 @@ package net.corda.testing.node.internal.network import net.corda.core.crypto.Crypto import net.corda.core.internal.CertRole +import net.corda.core.internal.toX500Name import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.contextLogger import net.corda.core.utilities.days import net.corda.core.utilities.minutes -import net.corda.core.utilities.seconds import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.ContentSignerBuilder import net.corda.nodeapi.internal.crypto.X509Utilities +import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames import net.corda.nodeapi.internal.crypto.certificateType import net.corda.nodeapi.internal.crypto.toJca -import org.bouncycastle.asn1.x500.X500Name +import net.corda.testing.core.createCRL import org.bouncycastle.asn1.x509.CRLDistPoint import org.bouncycastle.asn1.x509.DistributionPoint import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralNames -import org.bouncycastle.asn1.x509.IssuingDistributionPoint -import org.bouncycastle.asn1.x509.ReasonFlags -import org.bouncycastle.cert.jcajce.JcaX509CRLConverter -import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils -import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.ServerConnector import org.eclipse.jetty.server.handler.HandlerCollection @@ -36,11 +32,12 @@ import org.eclipse.jetty.servlet.ServletHolder import org.glassfish.jersey.server.ResourceConfig import org.glassfish.jersey.servlet.ServletContainer import java.io.Closeable -import java.math.BigInteger import java.net.InetSocketAddress +import java.net.URI import java.security.KeyPair import java.security.cert.X509CRL import java.security.cert.X509Certificate +import java.time.Duration import java.util.* import javax.security.auth.x500.X500Principal import javax.ws.rs.GET @@ -51,7 +48,7 @@ import kotlin.collections.ArrayList class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { companion object { - private const val SIGNATURE_ALGORITHM = "SHA256withECDSA" + private val logger = contextLogger() const val NODE_CRL = "node.crl" const val FORBIDDEN_CRL = "forbidden.crl" @@ -72,8 +69,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { null ) if (crlDistPoint != null) { - val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) - val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(X500Name.getInstance(it.encoded))) } + val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier)) + val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) } val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) } @@ -87,14 +84,17 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { } } - val revokedNodeCerts: MutableList = ArrayList() - val revokedIntermediateCerts: MutableList = ArrayList() + val revokedNodeCerts: MutableList = ArrayList() + val revokedIntermediateCerts: MutableList = ArrayList() val rootCa: CertificateAndKeyPair = DEV_ROOT_CA private lateinit var _intermediateCa: CertificateAndKeyPair val intermediateCa: CertificateAndKeyPair get() = _intermediateCa + @Volatile + var delay: Duration? = null + val hostAndPort: NetworkHostAndPort get() = server.connectors.mapNotNull { it as? ServerConnector } .map { NetworkHostAndPort(it.host, it.localPort) } @@ -106,7 +106,7 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"), DEV_INTERMEDIATE_CA.keyPair ) - println("Network management web services started on $hostAndPort") + logger.info("Network management web services started on $hostAndPort") } fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate, @@ -115,29 +115,20 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer) } - fun createRevocationList(signatureAlgorithm: String, - ca: CertificateAndKeyPair, - endpoint: String, - indirect: Boolean, - serialNumbers: List): X509CRL { - println("Generating CRL for $endpoint") - val builder = JcaX509v2CRLBuilder(ca.certificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis())) - val extensionUtils = JcaX509ExtensionUtils() - builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(ca.certificate)) - val issuingDistPointName = GeneralName(GeneralName.uniformResourceIdentifier, "http://$hostAndPort/crl/$endpoint") - // This is required and needs to match the certificate settings with respect to being indirect - val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false) - builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint) - builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis())) - serialNumbers.forEach { - builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold) - } - val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(ca.keyPair.private) - return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer)) + private fun createServerCRL(issuer: CertificateAndKeyPair, + endpoint: String, + indirect: Boolean, + revokedCerts: List): X509CRL { + logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}") + return createCRL( + issuer, + revokedCerts, + issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"), + indirect = indirect + ) } override fun close() { - println("Shutting down network management web services...") server.stop() server.join() } @@ -159,8 +150,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { @Path(NODE_CRL) @Produces("application/pkcs7-crl") fun getNodeCRL(): Response { - return Response.ok(crlServer.createRevocationList( - SIGNATURE_ALGORITHM, + crlServer.delay?.toMillis()?.let(Thread::sleep) + return Response.ok(crlServer.createServerCRL( crlServer.intermediateCa, NODE_CRL, false, @@ -179,8 +170,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { @Path(INTERMEDIATE_CRL) @Produces("application/pkcs7-crl") fun getIntermediateCRL(): Response { - return Response.ok(crlServer.createRevocationList( - SIGNATURE_ALGORITHM, + crlServer.delay?.toMillis()?.let(Thread::sleep) + return Response.ok(crlServer.createServerCRL( crlServer.rootCa, INTERMEDIATE_CRL, false, @@ -192,11 +183,11 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { @Path(EMPTY_CRL) @Produces("application/pkcs7-crl") fun getEmptyCRL(): Response { - return Response.ok(crlServer.createRevocationList( - SIGNATURE_ALGORITHM, + return Response.ok(crlServer.createServerCRL( crlServer.rootCa, EMPTY_CRL, - true, emptyList() + true, + emptyList() ).encoded).build() } } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt index 1dbf7249cb..ab080c3d7e 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/InternalTestUtils.kt @@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.SchemaMigration +import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.serialization.internal.amqp.AMQP_ENABLED import net.corda.testing.core.ALICE_NAME @@ -52,6 +53,8 @@ import java.io.IOException import java.net.ServerSocket import java.nio.file.Path import java.security.KeyPair +import java.security.cert.X509CRL +import java.security.cert.X509Certificate import java.util.* import java.util.jar.JarOutputStream import java.util.jar.Manifest @@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L return sslConfig } +fun fixedCrlSource(crls: Set): CrlSource { + return object : CrlSource { + override fun fetch(certificate: X509Certificate): Set = crls + } +} + /** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */ @SuppressWarnings("LongParameterList") fun createWireTransaction(inputs: List, From 0cc3ffe1d63c23d0614ddf37e76bc9fbcdbb9d7f Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Tue, 16 May 2023 08:45:01 +0100 Subject: [PATCH 02/12] ENT-9941: Moved new connector factory to node-api (#7369) --- .../kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt | 2 +- .../net/corda/nodeapi/internal}/NodeNettyConnectorFactory.kt | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) rename {node/src/main/kotlin/net/corda/node/services/messaging => node-api/src/main/kotlin/net/corda/nodeapi/internal}/NodeNettyConnectorFactory.kt (96%) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index 6183cfe818..4f5d05531f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -182,7 +182,7 @@ class ArtemisTcpTransport { threadPoolName: String, trace: Boolean): TransportConfiguration { return createTransport( - "net.corda.node.services.messaging.NodeNettyConnectorFactory", + NodeNettyConnectorFactory::class.java.name, hostAndPort, protocols, options, diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyConnectorFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt similarity index 96% rename from node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyConnectorFactory.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt index 1672031a11..47e046566e 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyConnectorFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt @@ -1,9 +1,8 @@ -package net.corda.node.services.messaging +package net.corda.nodeapi.internal import io.netty.channel.ChannelPipeline import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler -import net.corda.nodeapi.internal.ArtemisTcpTransport import org.apache.activemq.artemis.core.protocol.core.impl.ActiveMQClientProtocolManager import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnector import org.apache.activemq.artemis.spi.core.remoting.BufferHandler @@ -15,7 +14,6 @@ import org.apache.activemq.artemis.utils.ConfigurationHelper import java.util.concurrent.Executor import java.util.concurrent.ScheduledExecutorService -@Suppress("unused") class NodeNettyConnectorFactory : ConnectorFactory { override fun createConnector(configuration: MutableMap?, handler: BufferHandler?, From c0650213284c1e382f681033d8ac675962a76d39 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Thu, 18 May 2023 11:33:05 +0100 Subject: [PATCH 03/12] ENT-8827: The ordering of vault query results is clobbered by ServiceHub.loadStates --- .../kotlin/net/corda/core/node/ServiceHub.kt | 2 +- .../internal/ServicesForResolutionImpl.kt | 25 +- .../node/messaging/TwoPartyTradeFlowTests.kt | 4 +- .../statemachine/FlowSoftLocksTests.kt | 1 - .../node/services/vault/VaultQueryTests.kt | 80 ++++--- .../testing/internal/vault/VaultFiller.kt | 217 ++++++++---------- 6 files changed, 168 insertions(+), 161 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt index d63b63edf4..612e341a6f 100644 --- a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt @@ -64,7 +64,7 @@ interface ServicesForResolution { /** * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * - * @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. + * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction. */ // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // as the existing transaction store will become encrypted at some point diff --git a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt index f5836c0cc5..06e46992d4 100644 --- a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt @@ -2,6 +2,7 @@ package net.corda.node.internal import net.corda.core.contracts.* import net.corda.core.cordapp.CordappProvider +import net.corda.core.crypto.SecureHash import net.corda.core.internal.SerializedStateAndRef import net.corda.core.node.NetworkParameters import net.corda.core.node.ServicesForResolution @@ -9,8 +10,10 @@ import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.TransactionStorage +import net.corda.core.transactions.BaseTransaction import net.corda.core.transactions.ContractUpgradeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent @@ -26,25 +29,23 @@ data class ServicesForResolutionImpl( @Throws(TransactionResolutionException::class) override fun loadState(stateRef: StateRef): TransactionState<*> { - val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) - return stx.resolveBaseTransaction(this).outputs[stateRef.index] + return toBaseTransaction(stateRef.txhash).outputs[stateRef.index] } @Throws(TransactionResolutionException::class) override fun loadStates(stateRefs: Set): Set> { - return stateRefs.groupBy { it.txhash }.flatMap { - val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key) - val baseTx = stx.resolveBaseTransaction(this) - it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) } - }.toSet() + val baseTxs = HashMap() + return stateRefs.mapTo(LinkedHashSet()) { stateRef -> + val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction) + StateAndRef(baseTx.outputs[stateRef.index], stateRef) + } } @Throws(TransactionResolutionException::class, AttachmentResolutionException::class) override fun loadContractAttachment(stateRef: StateRef): Attachment { // We may need to recursively chase transactions if there are notary changes. fun inner(stateRef: StateRef, forContractClassName: String?): Attachment { - val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction - ?: throw TransactionResolutionException(stateRef.txhash) + val ctx = getSignedTransaction(stateRef.txhash).coreTransaction when (ctx) { is WireTransaction -> { val transactionState = ctx.outRef(stateRef.index).state @@ -69,4 +70,10 @@ data class ServicesForResolutionImpl( } return inner(stateRef, null) } + + private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this) + + private fun getSignedTransaction(txhash: SecureHash): SignedTransaction { + return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash) + } } diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 0c49ee44ac..28a5f3b973 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { bobNode.internals.disableDBCloseOnStop() bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { @@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { val issuer = bank.ref(1, 2, 3) bobNode.database.transaction { - VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer) + VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10) } val alicesFakePaper = aliceNode.database.transaction { fillUpForSeller(false, issuer, alice, diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt index 6f0fa3278c..1930e7ffd8 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowSoftLocksTests.kt @@ -244,7 +244,6 @@ class FlowSoftLocksTests { 100.DOLLARS, bankNode.services, thisManyStates, - thisManyStates, cashIssuer ) } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 09964b6602..1b139ab022 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -20,14 +20,13 @@ import net.corda.finance.* import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.DealState -import net.corda.finance.workflows.asset.selection.AbstractCashSelection import net.corda.finance.contracts.asset.Cash import net.corda.finance.schemas.CashSchemaV1 -import net.corda.finance.schemas.CashSchemaV1.PersistentCashState import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.workflows.CommercialPaperUtils +import net.corda.finance.workflows.asset.selection.AbstractCashSelection import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseTransaction @@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } protected fun consumeCash(amount: Amount) = vaultFiller.consumeCash(amount, CHARLIE) - private fun setUpDb(_database: CordaPersistence, delay: Long = 0) { - _database.transaction { + + private fun setUpDb(database: CordaPersistence, delay: Long = 0) { + database.transaction { // create new states vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER) val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ") @@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates) } - (1..3).forEach { + repeat(3) { val newAllStates = vaultService.queryBy(sorting = sorting, criteria = criteria).states assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates) @@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } val sorted = results.states.sortedBy { it.ref.toString() } assertThat(results.states).isEqualTo(sorted) - assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } + assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) } } } @@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")) // count fungible assets val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + val countCriteria = VaultCustomQueryCriteria(count) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } // count fungible assets - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) val fungibleStateCount = vaultService.queryBy>(countCriteria).otherResults.single() as Long assertThat(fungibleStateCount).isEqualTo(10L) @@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // UNCONSUMED states (default) // count fungible assets - val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) + val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) val fungibleStateCountUnconsumed = vaultService.queryBy>(countCriteriaUnconsumed).otherResults.single() as Long assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) @@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // CONSUMED states // count fungible assets - val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) + val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) val fungibleStateCountConsumed = vaultService.queryBy>(countCriteriaConsumed).otherResults.single() as Long assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) @@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val start = TODAY val end = TODAY.plus(30, ChronoUnit.DAYS) val recordedBetweenExpression = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, + TimeInstantType.RECORDED, ColumnPredicate.Between(start, end)) val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression) val results = vaultService.queryBy(criteria) @@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Future val startFuture = TODAY.plus(1, ChronoUnit.DAYS) val recordedBetweenExpressionFuture = TimeCondition( - QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) + TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture) assertThat(vaultService.queryBy(criteriaFuture).states).isEmpty() } @@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { consumeCash(100.DOLLARS) val asOfDateTime = TODAY val consumedAfterExpression = TimeCondition( - QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) + TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED, timeCondition = consumedAfterExpression) val results = vaultService.queryBy(criteria) @@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } // pagination: invalid page size + @Suppress("INTEGER_OVERFLOW") @Test(timeout=300_000) fun `invalid page size`() { expectedEx.expect(VaultQueryException::class.java) @@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { database.transaction { vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER) - @Suppress("EXPECTED_CONDITION") - val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648 + val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648 val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) vaultService.queryBy(criteria, paging = pagingSpec) } @@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { println("$index : $any") } assertThat(results.otherResults.size).isEqualTo(402) - val instants = results.otherResults.filter { it is Instant }.map { it as Instant } + val instants = results.otherResults.filterIsInstance() assertThat(instants).isSorted - val longs = results.otherResults.filter { it is Long }.map { it as Long } + val longs = results.otherResults.filterIsInstance() assertThat(longs.size).isEqualTo(201) assertThat(instants.size).isEqualTo(201) assertThat(longs.sum()).isEqualTo(20100L) @@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() { database.transaction { val uid = UniqueIdentifier("999") - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid) - vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234") + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid) + vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234") val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id)) val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234")) @@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties { } } + @Test(timeout = 300_000) + fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() { + val linearNumbers = Array(2) { LongArray(2) } + // Make sure states from the same transaction are not given consecutive linear numbers. + linearNumbers[0][0] = 1L + linearNumbers[0][1] = 3L + linearNumbers[1][0] = 2L + linearNumbers[1][1] = 4L + + val results = database.transaction { + vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex -> + DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex]) + } + + val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber")) + vaultService.queryBy(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn))) + } + assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L)) + } + @Test(timeout=300_000) fun `return consumed linear states for a given linear id`() { database.transaction { @@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { services.recordTransactions(commercialPaper2) val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) val result = vaultService.queryBy(criteria1) @@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days) val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L) - val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) - val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex) - val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex) + val criteria1 = VaultCustomQueryCriteria(ccyIndex) + val criteria2 = VaultCustomQueryCriteria(maturityIndex) + val criteria3 = VaultCustomQueryCriteria(faceValueIndex) vaultService.queryBy(criteria1.and(criteria3).and(criteria2)) } @@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties { val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) val results = builder { - val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode) - val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L) + val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode) + val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L) val customCriteria1 = VaultCustomQueryCriteria(currencyIndex) val customCriteria2 = VaultCustomQueryCriteria(quantityIndex) @@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // Enrich and override QueryCriteria with additional default attributes (such as soft locks) val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich - softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), + softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), status = Vault.StateStatus.UNCONSUMED) // override // Sorting val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) @@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } @@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate { assertThat(snapshot.states).hasSize(0) val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states this.session.flush() - vaultFiller.consumeLinearStates(states.toList()) + vaultFiller.consumeStates(states) updates } diff --git a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt index 467b54ea22..f2775e1878 100644 --- a/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt +++ b/testing/test-utils/src/main/kotlin/net/corda/testing/internal/vault/VaultFiller.kt @@ -1,6 +1,20 @@ +@file:Suppress("LongParameterList") + package net.corda.testing.internal.vault -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.AttachmentConstraint +import net.corda.core.contracts.AutomaticPlaceholderConstraint +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.CommandAndState +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.Issued +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.PartyAndReference +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.UniqueIdentifier import net.corda.core.crypto.Crypto import net.corda.core.crypto.SignatureMetadata import net.corda.core.identity.AbstractParty @@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Obligation import net.corda.finance.contracts.asset.OnLedgerAsset import net.corda.finance.workflows.asset.CashUtils -import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState -import net.corda.testing.core.DummyCommandData import net.corda.testing.core.TestIdentity import net.corda.testing.core.dummyCommand import net.corda.testing.core.singleIdentity @@ -32,6 +44,7 @@ import java.time.Duration import java.time.Instant import java.time.Instant.now import java.util.* +import kotlin.math.floor /** * The service hub should provide at least a key management service and a storage service. @@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor( private val rngFactory: () -> Random = { Random(0L) }) { companion object { fun calculateRandomlySizedAmounts(howMuch: Amount, min: Int, max: Int, rng: Random): LongArray { - val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt() + val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt() val baseSize = howMuch.quantity / numSlots check(baseSize > 0) { baseSize } @@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor( issuerServices: ServiceHub = services, participants: List = emptyList(), includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val participantsToUse = if (includeMe) participants.plus(me) else participants - - val transactions: List = dealIds.map { - // Issue a deal state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) - } - val stx = issuerServices.signInitialTransaction(dummyIssue) - return@map services.addSignature(stx, defaultNotary.publicKey) + return fillWithTestStates( + txCount = dealIds.size, + participants = participants, + includeMe = includeMe, + services = issuerServices + ) { participantsToUse, txIndex, _ -> + DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearStates(numberToCreate: Int, + fun fillWithSomeTestLinearStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), uniqueIdentifier: UniqueIdentifier? = null, @@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor( linearTimestamp: Instant = now(), constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, includeMe: Boolean = true): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val participantsToUse = if (includeMe) participants.plus(me) else participants - val transactions: List = (1..numberToCreate).map { - // Issue a Linear state - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - addOutputState(DummyLinearContract.State( - linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), - participants = participantsToUse, - linearString = linearString, - linearNumber = linearNumber, - linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID, - constraint = constraint) - addCommand(dummyCommand()) - } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) + return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ -> + DummyLinearContract.State( + linearId = uniqueIdentifier ?: UniqueIdentifier(externalId), + participants = participantsToUse, + linearString = linearString, + linearNumber = linearNumber, + linearBoolean = linearBoolean, + linearTimestamp = linearTimestamp + ) } - val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE - services.recordTransactions(statesToRecord, transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) } @JvmOverloads - fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, + fun fillWithSomeTestLinearAndDealStates(txCount: Int, externalId: String? = null, participants: List = emptyList(), linearString: String = "", linearNumber: Long = 0L, linearBoolean: Boolean = false, - linearTimestamp: Instant = now()): Vault { - val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey - val me = AnonymousParty(myKey) - val issuerKey = defaultNotary.keyPair - val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) - val transactions: List = (1..numberToCreate).map { - val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { - // Issue a Linear state - addOutputState(DummyLinearContract.State( + linearTimestamp: Instant = now()): Vault { + return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex -> + when (stateIndex) { + 0 -> DummyLinearContract.State( linearId = UniqueIdentifier(externalId), - participants = participants.plus(me), + participants = participantsToUse, linearString = linearString, linearNumber = linearNumber, linearBoolean = linearBoolean, - linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) - // Issue a Deal state - addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) - addCommand(dummyCommand()) + linearTimestamp = linearTimestamp + ) + else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse) } - return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) } - services.recordTransactions(transactions) - // Get all the StateAndRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - return Vault(states) } - @JvmOverloads - fun fillWithSomeTestCash(howMuch: Amount, - issuerServices: ServiceHub, - thisManyStates: Int, - issuedBy: PartyAndReference, - owner: AbstractParty? = null, - rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord) - /** * Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them * to the vault. This is intended for unit tests. By default the cash is owned by the legal @@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor( * @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). */ + @JvmOverloads fun fillWithSomeTestCash(howMuch: Amount, issuerServices: ServiceHub, atLeastThisManyStates: Int, - atMostThisManyStates: Int, issuedBy: PartyAndReference, owner: AbstractParty? = null, rng: Random? = null, - statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT, + atMostThisManyStates: Int = atLeastThisManyStates): Vault { val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory()) // We will allocate one state to one transaction, for simplicities sake. val cash = Cash() @@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor( cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary) return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) } - services.recordTransactions(statesToRecord, transactions) - // Get all the StateRefs of all the generated transactions. - val states = transactions.flatMap { stx -> - stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } - } - - return Vault(states) + return recordTransactions(transactions, statesToRecord) } /** * Records a dummy state in the Vault (useful for creating random states when testing vault queries) */ - fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())) : Vault { - val outputState = TransactionState( - data = DummyState(Random().nextInt(), participants = participants), - contract = DummyContract.PROGRAM_ID, - notary = defaultNotary.party - ) - val participantKeys : List = participants.map { it.owningKey } - val builder = TransactionBuilder() - .addOutputState(outputState) - .addCommand(DummyCommandData, participantKeys) - val stxn = services.signInitialTransaction(builder) - services.recordTransactions(stxn) - return Vault(setOf(stxn.tx.outRef(0))) + fun fillWithDummyState(participants: List = listOf(services.myInfo.singleIdentity())): Vault { + return fillWithTestStates(participants = participants) { participantsToUse, _, _ -> + DummyState(Random().nextInt(), participants = participantsToUse) + } } - /** - * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. - */ - fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount>, owner: AbstractParty, notary: Party) - = OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue()) - + fun fillWithTestStates(txCount: Int = 1, + statesPerTx: Int = 1, + participants: List = emptyList(), + constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, + includeMe: Boolean = true, + services: ServiceHub = this.services, + genOutputState: (participantsToUse: List, txIndex: Int, stateIndex: Int) -> T): Vault { + val issuerKey = defaultNotary.keyPair + val signatureMetadata = SignatureMetadata( + services.myInfo.platformVersion, + Crypto.findSignatureScheme(issuerKey.public).schemeNumberID + ) + val participantsToUse = if (includeMe) { + participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey) + } else { + participants + } + val transactions = Array(txCount) { txIndex -> + val builder = TransactionBuilder(notary = defaultNotary.party) + repeat(statesPerTx) { stateIndex -> + builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint) + } + builder.addCommand(dummyCommand()) + services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata) + } + val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE + return recordTransactions(transactions.asList(), statesToRecord) + } /** * @@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor( val me = AnonymousParty(myKey) val issuance = TransactionBuilder(null as Party?) - generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary) + OnLedgerAsset.generateIssue( + issuance, + TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary), + Obligation.Commands.Issue() + ) val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) - services.recordTransactions(transaction) - return Vault(setOf(transaction.tx.outRef(0))) + return recordTransactions(listOf(transaction)) } - private fun consume(states: List>) { + fun consumeStates(states: Iterable>) { // Create a txn consuming different contract types states.forEach { val builder = TransactionBuilder(notary = altNotary).apply { @@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor( } } - fun consumeDeals(dealStates: List>) = consume(dealStates) - fun consumeLinearStates(linearStates: List>) = consume(linearStates) + fun consumeDeals(dealStates: List>) = consumeStates(dealStates) + fun consumeLinearStates(linearStates: List>) = consumeStates(linearStates) fun evolveLinearStates(linearStates: List>) = consumeAndProduce(linearStates) fun evolveLinearState(linearState: StateAndRef): StateAndRef = consumeAndProduce(linearState) + /** * Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios, * where nodes have a default identity. @@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor( services.recordTransactions(spendTx) return update.getOrThrow(Duration.ofSeconds(3)) } + + private fun recordTransactions(transactions: Iterable, + statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault { + services.recordTransactions(statesToRecord, transactions) + // Get all the StateAndRefs of all the generated transactions. + val states = transactions.flatMap { stx -> + stx.tx.outputs.indices.map { i -> stx.tx.outRef(i) } + } + return Vault(states) + } } @@ -344,4 +326,3 @@ data class CommodityState( override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner)) } - From a817218b08775736919e70f11ecf5e1673fd9d27 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Thu, 1 Jun 2023 15:51:58 +0100 Subject: [PATCH 04/12] ENT-9806: Added peer info to SSL handshake logging, and other changes for ENT merge (#7380) --- .../nodeapi/internal/ArtemisTcpTransport.kt | 2 +- .../internal/protonwrapper/netty/SSLHelper.kt | 4 +++- .../services/messaging/SimpleAMQPClient.kt | 12 ++++------- .../messaging/NodeNettyAcceptorFactory.kt | 21 ++++++++++++------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index 686c36488d..3b5adaf934 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -85,7 +85,7 @@ class ArtemisTcpTransport { fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, - trustManagerFactory: TrustManagerFactory?, + trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory), enableSSL: Boolean = true, threadPoolName: String = "P2PServer", trace: Boolean = false, diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt index 7de3e5e302..dc207f2c7b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/SSLHelper.kt @@ -305,9 +305,11 @@ internal fun splitKeystore(config: AMQPConfiguration): Map logger.info("SSL handshake completed in ${duration}ms with ${engine().session.peerPrincipal}") - it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms") + it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with $peer") + it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms with $peer") else -> when (it.cause()) { - is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms") - is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms") - else -> logger.warn("SSL handshake failed after ${duration}ms", it.cause()) + is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms with $peer") + is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms with $peer") + else -> logger.warn("SSL handshake failed after ${duration}ms with $peer", it.cause()) } } } From 4dcd9245d3f084cd13e544fea7fc31b45ae79ef9 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Thu, 1 Jun 2023 17:33:04 +0100 Subject: [PATCH 05/12] ENT-9806: Using Artemis SSLContextFactory service to pass in custom TrustManagerFactory. This removes the need to copy code from NettyAcceptor. --- .../nodeapi/internal/ArtemisTcpTransport.kt | 11 +- .../CertificateRevocationListNodeTests.kt | 4 +- .../internal/artemis/BrokerJaasLoginModule.kt | 2 +- .../messaging/ArtemisMessagingServer.kt | 1 - .../messaging/NodeNettyAcceptorFactory.kt | 133 ++---------------- .../messaging/NodeSSLContextFactory.kt | 59 ++++++++ ...pi.core.remoting.ssl.OpenSSLContextFactory | 1 + ...is.spi.core.remoting.ssl.SSLContextFactory | 1 + 8 files changed, 80 insertions(+), 132 deletions(-) create mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt create mode 100644 node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory create mode 100644 node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index cc4f1b98b9..24813ffaa8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -183,10 +183,7 @@ class ArtemisTcpTransport { options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 if (trustManagerFactory != null) { // NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use - // more customised instances which use our revocation checkers, which we pass directly into NodeNettyAcceptorFactory. - // - // This, however, requires copying a lot of code from NettyAcceptor into NodeNettyAcceptor. The version of Artemis in - // Corda 4.9 solves this problem by introducing a "trustManagerFactoryPlugin" config option. + // more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory. options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory } return createTransport( @@ -208,6 +205,10 @@ class ArtemisTcpTransport { threadPoolName: String, trace: Boolean, remotingThreads: Int?): TransportConfiguration { + if (enableSSL) { + // This is required to stop Client checking URL address vs. Server provided certificate + options[TransportConstants.VERIFY_HOST_PROP_NAME] = false + } return createTransport( NodeNettyConnectorFactory::class.java.name, hostAndPort, @@ -232,8 +233,6 @@ class ArtemisTcpTransport { if (enableSSL) { options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",") options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",") - // This is required to stop Client checking URL address vs. Server provided certificate - options[TransportConstants.VERIFY_HOST_PROP_NAME] = false } // By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357) options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1 diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt index 1c21ce590c..b460cc2268 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt @@ -497,7 +497,9 @@ class ArtemisServerRevocationTest : AbstractServerRevocationTest() { } val queueName = "${P2P_PREFIX}Test" - artemisNode.client.started!!.session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) + artemisNode.client.started!!.session.createQueue( + QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true) + ) val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus) diff --git a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt index cbfce17eb1..4038a0f2ef 100644 --- a/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt +++ b/node/src/main/kotlin/net/corda/node/internal/artemis/BrokerJaasLoginModule.kt @@ -140,7 +140,7 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() { // This check is redundant as it was performed already during the SSL handshake CertificateChainCheckPolicy.RootMustMatch .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore) - .checkCertificateChain(certificates!!) + .checkCertificateChain(certificates) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) } else -> { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index 5849b2484f..f03242d5e0 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -42,7 +42,6 @@ import org.apache.activemq.artemis.core.security.Role import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager -import java.io.IOException import java.lang.Long.max import javax.annotation.concurrent.ThreadSafe import javax.security.auth.login.AppConfigurationEntry diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt index 5cc3096549..ddca20e266 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt @@ -5,24 +5,14 @@ import io.netty.channel.ChannelHandlerContext import io.netty.channel.group.ChannelGroup import io.netty.handler.logging.LogLevel import io.netty.handler.logging.LoggingHandler -import io.netty.handler.ssl.SslContext -import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslHandler import io.netty.handler.ssl.SslHandshakeTimeoutException -import io.netty.handler.ssl.SslProvider import net.corda.core.internal.declaredField import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.ArtemisTcpTransport -import net.corda.nodeapi.internal.config.CertificateStore -import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext -import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor -import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor -import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants -import org.apache.activemq.artemis.core.remoting.impl.ssl.SSLSupport -import org.apache.activemq.artemis.core.server.ActiveMQServerLogger import org.apache.activemq.artemis.core.server.balancing.RedirectHandler import org.apache.activemq.artemis.core.server.cluster.ClusterConnection import org.apache.activemq.artemis.spi.core.protocol.ProtocolManager @@ -30,24 +20,20 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory import org.apache.activemq.artemis.spi.core.remoting.BufferHandler import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener +import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider +import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider import org.apache.activemq.artemis.utils.ConfigurationHelper import org.apache.activemq.artemis.utils.actors.OrderedExecutor import java.net.SocketAddress import java.nio.channels.ClosedChannelException -import java.nio.file.Paths -import java.security.PrivilegedExceptionAction import java.time.Duration import java.util.concurrent.Executor import java.util.concurrent.ScheduledExecutorService import java.util.regex.Pattern -import javax.net.ssl.KeyManagerFactory -import javax.net.ssl.SSLContext import javax.net.ssl.SSLEngine import javax.net.ssl.SSLPeerUnverifiedException -import javax.net.ssl.TrustManagerFactory -import javax.security.auth.Subject -@Suppress("unused", "TooGenericExceptionCaught", "ComplexMethod", "MagicNumber", "TooManyFunctions") +@Suppress("unused") // Used via reflection in ArtemisTcpTransport class NodeNettyAcceptorFactory : AcceptorFactory { override fun createAcceptor(name: String?, clusterConnection: ClusterConnection?, @@ -74,6 +60,12 @@ class NodeNettyAcceptorFactory : AcceptorFactory { { companion object { private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") + + init { + // Make sure Artemis isn't using another (Open)SSLContextFactory + check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory) + check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory) + } } private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) @@ -100,7 +92,7 @@ class NodeNettyAcceptorFactory : AcceptorFactory { @Synchronized override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler { applyThreadPoolName() - val engine = getSSLEngine(alloc, peerHost, peerPort) + val engine = super.getSslHandler(alloc, peerHost, peerPort).engine() val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace) val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? if (handshakeTimeout != null) { @@ -118,111 +110,6 @@ class NodeNettyAcceptorFactory : AcceptorFactory { Thread.currentThread().name = "$threadPoolName-${matcher.group(1)}" // Preserve the pool thread number } } - - /** - * This is a copy of [NettyAcceptor.getSslHandler] so that we can provide different implementations for [loadOpenSslEngine] and - * [loadJdkSslEngine]. [NodeNettyAcceptor], instead of creating a default [TrustManagerFactory], will simply use the provided one in - * the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] configuration. - */ - private fun getSSLEngine(alloc: ByteBufAllocator?): SSLEngine { - val engine = if (sslProvider == TransportConstants.OPENSSL_PROVIDER) { - loadOpenSslEngine(alloc) - } else { - loadJdkSslEngine() - } - engine.useClientMode = false - if (needClientAuth) { - engine.needClientAuth = true - } - - // setting the enabled cipher suites resets the enabled protocols so we need - // to save the enabled protocols so that after the customer cipher suite is enabled - // we can reset the enabled protocols if a customer protocol isn't specified - val originalProtocols = engine.enabledProtocols - if (enabledCipherSuites != null) { - try { - engine.enabledCipherSuites = SSLSupport.parseCommaSeparatedListIntoArray(enabledCipherSuites) - } catch (e: IllegalArgumentException) { - ActiveMQServerLogger.LOGGER.invalidCipherSuite(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedCipherSuites)) - throw e - } - } - if (enabledProtocols != null) { - try { - engine.enabledProtocols = SSLSupport.parseCommaSeparatedListIntoArray(enabledProtocols) - } catch (e: IllegalArgumentException) { - ActiveMQServerLogger.LOGGER.invalidProtocol(SSLSupport.parseArrayIntoCommandSeparatedList(engine.supportedProtocols)) - throw e - } - } else { - engine.enabledProtocols = originalProtocols - } - return engine - } - - /** - * Copy of [NettyAcceptor.loadOpenSslEngine] which invokes our custom [createOpenSslContext]. - */ - private fun loadOpenSslEngine(alloc: ByteBufAllocator?): SSLEngine { - val context = try { - // We copied all this code just so we could replace the SSLSupport.createNettyContext method call with our own one. - createOpenSslContext() - } catch (e: Exception) { - throw IllegalStateException("Unable to create NodeNettyAcceptor", e) - } - return Subject.doAs(null, PrivilegedExceptionAction { - context.newEngine(alloc) - }) - } - - /** - * Copy of [NettyAcceptor.loadJdkSslEngine] which invokes our custom [createJdkSSLContext]. - */ - private fun loadJdkSslEngine(): SSLEngine { - val context = try { - // We copied all this code just so we could replace the SSLHelper.createContext method call with our own one. - createJdkSSLContext() - } catch (e: Exception) { - throw IllegalStateException("Unable to create NodeNettyAcceptor", e) - } - return Subject.doAs(null, PrivilegedExceptionAction { - context.createSSLEngine() - }) - } - - /** - * Create an [SSLContext] using the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config. - */ - private fun createJdkSSLContext(): SSLContext { - return createAndInitSslContext( - createKeyManagerFactory(), - configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? - ) - } - - /** - * Create an [SslContext] using the the [TrustManagerFactory] provided on the [ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] config. - */ - private fun createOpenSslContext(): SslContext { - return SslContextBuilder - .forServer(createKeyManagerFactory()) - .sslProvider(SslProvider.OPENSSL) - .trustManager(configuration[ArtemisTcpTransport.TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?) - .build() - } - - private fun createKeyManagerFactory(): KeyManagerFactory { - return keyManagerFactory(CertificateStore.fromFile(Paths.get(keyStorePath), keyStorePassword, keyStorePassword, false)) - } - - // Replicate the fields which are private in NettyAcceptor - private val sslProvider = ConfigurationHelper.getStringProperty(TransportConstants.SSL_PROVIDER, TransportConstants.DEFAULT_SSL_PROVIDER, configuration) - private val needClientAuth = ConfigurationHelper.getBooleanProperty(TransportConstants.NEED_CLIENT_AUTH_PROP_NAME, TransportConstants.DEFAULT_NEED_CLIENT_AUTH, configuration) - private val enabledCipherSuites = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME, TransportConstants.DEFAULT_ENABLED_CIPHER_SUITES, configuration) - private val enabledProtocols = ConfigurationHelper.getStringProperty(TransportConstants.ENABLED_PROTOCOLS_PROP_NAME, TransportConstants.DEFAULT_ENABLED_PROTOCOLS, configuration) - private val keyStorePath = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PATH_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PATH, configuration) - private val keyStoreProvider = ConfigurationHelper.getStringProperty(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PROVIDER, configuration) - private val keyStorePassword = ConfigurationHelper.getPasswordProperty(TransportConstants.KEYSTORE_PASSWORD_PROP_NAME, TransportConstants.DEFAULT_KEYSTORE_PASSWORD, configuration, ActiveMQDefaultConfiguration.getPropMaskPassword(), ActiveMQDefaultConfiguration.getPropPasswordCodec()) } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt new file mode 100644 index 0000000000..38f60d7a57 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeSSLContextFactory.kt @@ -0,0 +1,59 @@ +package net.corda.node.services.messaging + +import io.netty.handler.ssl.SslContext +import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.SslProvider +import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME +import net.corda.nodeapi.internal.config.CertificateStore +import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext +import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory +import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory +import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory +import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig +import java.nio.file.Paths +import javax.net.ssl.KeyManagerFactory +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManagerFactory + +class NodeSSLContextFactory : DefaultSSLContextFactory() { + override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map): SSLContext { + val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? + return if (trustManagerFactory != null) { + createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory) + } else { + super.getSSLContext(config, additionalOpts) + } + } + + override fun getPriority(): Int { + // We make sure this factory is the one that's chosen, so any sufficiently large value will do. + return 15 + } +} + + +class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() { + override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map): SslContext { + val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory? + return if (trustManagerFactory != null) { + SslContextBuilder + .forServer(loadKeyManagerFactory(config)) + .sslProvider(SslProvider.OPENSSL) + .trustManager(trustManagerFactory) + .build() + } else { + super.getServerSslContext(config, additionalOpts) + } + } + + override fun getPriority(): Int { + // We make sure this factory is the one that's chosen, so any sufficiently large value will do. + return 15 + } +} + + +private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory { + val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false) + return keyManagerFactory(keyStore) +} diff --git a/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory new file mode 100644 index 0000000000..6b69d9d3ff --- /dev/null +++ b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactory @@ -0,0 +1 @@ +net.corda.node.services.messaging.NodeOpenSSLContextFactory diff --git a/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory new file mode 100644 index 0000000000..59e57dca26 --- /dev/null +++ b/node/src/main/resources/META-INF/services/org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactory @@ -0,0 +1 @@ +net.corda.node.services.messaging.NodeSSLContextFactory From 0bfce451ea50b7ada2be8fdcbe3a171394c67119 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Tue, 6 Jun 2023 16:16:59 +0100 Subject: [PATCH 06/12] ENT-10013: Vault service refactoring backport --- .../corda/core/node/services/VaultService.kt | 25 +- .../net/corda/node/internal/AbstractNode.kt | 2 +- .../internal/NodeServicesForResolution.kt | 15 ++ .../internal/ServicesForResolutionImpl.kt | 20 +- .../node/migration/VaultStateMigration.kt | 5 +- .../PersistentScheduledFlowRepository.kt | 7 +- .../PersistentUniquenessProvider.kt | 9 +- .../node/services/vault/NodeVaultService.kt | 252 +++++++++++------- .../corda/node/services/vault/VaultSchema.kt | 18 ++ .../bftsmart/BFTSmartNotaryService.kt | 10 +- .../corda/notary/jpa/JPAUniquenessProvider.kt | 11 +- .../persistence/HibernateConfigurationTest.kt | 22 +- .../node/services/vault/VaultQueryTests.kt | 4 +- .../vault/VaultSoftLockManagerTest.kt | 4 +- .../net/corda/testing/node/MockServices.kt | 10 +- 15 files changed, 253 insertions(+), 161 deletions(-) create mode 100644 node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt diff --git a/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt b/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt index 3b6db58dcd..51d61cc214 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt @@ -1,3 +1,5 @@ +@file:Suppress("LongParameterList") + package net.corda.core.node.services import co.paralleluniverse.fibers.Suspendable @@ -197,8 +199,7 @@ class Vault(val states: Iterable>) { * 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL]. * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by). * - * Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata - * results will be empty). + * Note: currently [otherResults] is used only for aggregate functions (in which case, [states] and [statesMetadata] will be empty). */ @CordaSerializable data class Page(val states: List>, @@ -213,11 +214,11 @@ class Vault(val states: Iterable>) { val contractStateClassName: String, val recordedTime: Instant, val consumedTime: Instant?, - val status: Vault.StateStatus, + val status: StateStatus, val notary: AbstractParty?, val lockId: String?, val lockUpdateTime: Instant?, - val relevancyStatus: Vault.RelevancyStatus? = null, + val relevancyStatus: RelevancyStatus? = null, val constraintInfo: ConstraintInfo? = null ) { fun copy( @@ -225,7 +226,7 @@ class Vault(val states: Iterable>) { contractStateClassName: String = this.contractStateClassName, recordedTime: Instant = this.recordedTime, consumedTime: Instant? = this.consumedTime, - status: Vault.StateStatus = this.status, + status: StateStatus = this.status, notary: AbstractParty? = this.notary, lockId: String? = this.lockId, lockUpdateTime: Instant? = this.lockUpdateTime @@ -237,11 +238,11 @@ class Vault(val states: Iterable>) { contractStateClassName: String = this.contractStateClassName, recordedTime: Instant = this.recordedTime, consumedTime: Instant? = this.consumedTime, - status: Vault.StateStatus = this.status, + status: StateStatus = this.status, notary: AbstractParty? = this.notary, lockId: String? = this.lockId, lockUpdateTime: Instant? = this.lockUpdateTime, - relevancyStatus: Vault.RelevancyStatus? + relevancyStatus: RelevancyStatus? ): StateMetadata { return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint)) } @@ -249,9 +250,9 @@ class Vault(val states: Iterable>) { companion object { @Deprecated("No longer used. The vault does not emit empty updates") - val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL, references = emptySet()) + val NoUpdate = Update(emptySet(), emptySet(), type = UpdateType.GENERAL, references = emptySet()) @Deprecated("No longer used. The vault does not emit empty updates") - val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE, references = emptySet()) + val NoNotaryUpdate = Update(emptySet(), emptySet(), type = UpdateType.NOTARY_CHANGE, references = emptySet()) } } @@ -302,7 +303,7 @@ interface VaultService { fun whenConsumed(ref: StateRef): CordaFuture> { val query = QueryCriteria.VaultQueryCriteria( stateRefs = listOf(ref), - status = Vault.StateStatus.CONSUMED + status = StateStatus.CONSUMED ) val result = trackBy(query) val snapshot = result.snapshot.states @@ -358,8 +359,8 @@ interface VaultService { /** * Helper function to determine spendable states and soft locking them. * Currently performance will be worse than for the hand optimised version in - * [Cash.unconsumedCashStatesForSpending]. However, this is fully generic and can operate with custom [FungibleState] - * and [FungibleAsset] states. + * [net.corda.finance.workflows.asset.selection.AbstractCashSelection.unconsumedCashStatesForSpending]. However, this is fully generic + * and can operate with custom [FungibleState] and [FungibleAsset] states. * @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states. * @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the * [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 7be9b94dc9..d4a940bb84 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -1077,7 +1077,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, networkParameters: NetworkParameters) protected open fun makeVaultService(keyManagementService: KeyManagementService, - services: ServicesForResolution, + services: NodeServicesForResolution, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader) diff --git a/node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt b/node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt new file mode 100644 index 0000000000..5baa528297 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/internal/NodeServicesForResolution.kt @@ -0,0 +1,15 @@ +package net.corda.node.internal + +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionResolutionException +import net.corda.core.node.ServicesForResolution +import java.util.LinkedHashSet + +interface NodeServicesForResolution : ServicesForResolution { + @Throws(TransactionResolutionException::class) + override fun loadStates(stateRefs: Set): Set> = loadStates(stateRefs, LinkedHashSet()) + + fun >> loadStates(input: Iterable, output: C): C +} diff --git a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt index 06e46992d4..ffb21894c1 100644 --- a/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/ServicesForResolutionImpl.kt @@ -1,11 +1,18 @@ package net.corda.node.internal -import net.corda.core.contracts.* +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.AttachmentResolutionException +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionResolutionException +import net.corda.core.contracts.TransactionState import net.corda.core.cordapp.CordappProvider import net.corda.core.crypto.SecureHash import net.corda.core.internal.SerializedStateAndRef +import net.corda.core.internal.uncheckedCast import net.corda.core.node.NetworkParameters -import net.corda.core.node.ServicesForResolution import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService @@ -23,7 +30,7 @@ data class ServicesForResolutionImpl( override val cordappProvider: CordappProvider, override val networkParametersService: NetworkParametersService, private val validatedTransactions: TransactionStorage -) : ServicesForResolution { +) : NodeServicesForResolution { override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?: throw IllegalArgumentException("No current parameters in network parameters storage") @@ -32,12 +39,11 @@ data class ServicesForResolutionImpl( return toBaseTransaction(stateRef.txhash).outputs[stateRef.index] } - @Throws(TransactionResolutionException::class) - override fun loadStates(stateRefs: Set): Set> { + override fun >> loadStates(input: Iterable, output: C): C { val baseTxs = HashMap() - return stateRefs.mapTo(LinkedHashSet()) { stateRef -> + return input.mapTo(output) { stateRef -> val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction) - StateAndRef(baseTx.outputs[stateRef.index], stateRef) + StateAndRef(uncheckedCast(baseTx.outputs[stateRef.index]), stateRef) } } diff --git a/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt b/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt index f33418c28e..b765685910 100644 --- a/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt +++ b/node/src/main/kotlin/net/corda/node/migration/VaultStateMigration.kt @@ -2,7 +2,6 @@ package net.corda.node.migration import liquibase.database.Database import net.corda.core.contracts.* -import net.corda.core.crypto.SecureHash import net.corda.core.identity.CordaX500Name import net.corda.core.node.services.Vault import net.corda.core.schemas.MappedSchema @@ -18,6 +17,7 @@ import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.VaultSchemaV1 +import net.corda.node.services.vault.toStateRef import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.SchemaMigration @@ -61,8 +61,7 @@ class VaultStateMigration : CordaMigration() { private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef { val persistentStateRef = persistentState.stateRef ?: throw VaultStateMigrationException("Persistent state ref missing from state") - val txHash = SecureHash.create(persistentStateRef.txId) - val stateRef = StateRef(txHash, persistentStateRef.index) + val stateRef = persistentStateRef.toStateRef() val state = try { servicesForResolution.loadState(stateRef) } catch (e: Exception) { diff --git a/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt b/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt index 2208eef88f..f62db2eee4 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/PersistentScheduledFlowRepository.kt @@ -2,8 +2,8 @@ package net.corda.node.services.events import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.StateRef -import net.corda.core.crypto.SecureHash import net.corda.core.schemas.PersistentStateRef +import net.corda.node.services.vault.toStateRef import net.corda.nodeapi.internal.persistence.CordaPersistence interface ScheduledFlowRepository { @@ -25,9 +25,8 @@ class PersistentScheduledFlowRepository(val database: CordaPersistence) : Schedu } private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair { - val txId = scheduledStateRecord.output.txId - val index = scheduledStateRecord.output.index - return Pair(StateRef(SecureHash.create(txId), index), ScheduledStateRef(StateRef(SecureHash.create(txId), index), scheduledStateRecord.scheduledAt)) + val stateRef = scheduledStateRecord.output.toStateRef() + return Pair(stateRef, ScheduledStateRef(stateRef, scheduledStateRecord.scheduledAt)) } override fun delete(key: StateRef): Boolean { diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt index 66ec2007fa..aa69d50db3 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt @@ -25,6 +25,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.serialize import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.node.services.vault.toStateRef import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX @@ -157,13 +158,7 @@ class PersistentUniquenessProvider(val clock: Clock, val database: CordaPersiste toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, fromPersistentEntity = { //TODO null check will become obsolete after making DB/JPA columns not nullable - val txId = it.id.txId - val index = it.id.index - Pair( - StateRef(txhash = SecureHash.create(txId), index = index), - SecureHash.create(it.consumingTxHash) - ) - + Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash)) }, toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> CommittedState( diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 6db962cdce..ac0913604c 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -3,28 +3,65 @@ package net.corda.node.services.vault import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand import net.corda.core.CordaRuntimeException -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.FungibleAsset +import net.corda.core.contracts.FungibleState +import net.corda.core.contracts.Issued +import net.corda.core.contracts.OwnableState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionState import net.corda.core.crypto.SecureHash import net.corda.core.crypto.containsAny import net.corda.core.flows.HospitalizeFlowException -import net.corda.core.internal.* +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.TransactionDeserialisationException +import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.tee +import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.DataFeed -import net.corda.core.node.ServicesForResolution import net.corda.core.node.StatesToRecord -import net.corda.core.node.services.* -import net.corda.core.node.services.Vault.ConstraintInfo.Companion.constraintInfo -import net.corda.core.node.services.vault.* +import net.corda.core.node.services.KeyManagementService +import net.corda.core.node.services.StatesNotAvailableException +import net.corda.core.node.services.Vault +import net.corda.core.node.services.VaultQueryException +import net.corda.core.node.services.VaultService +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM +import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE +import net.corda.core.node.services.vault.PageSpecification +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.Sort +import net.corda.core.node.services.vault.SortAttribute +import net.corda.core.node.services.vault.builder import net.corda.core.observable.internal.OnResilientSubscribe import net.corda.core.schemas.PersistentStateRef import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.core.transactions.* -import net.corda.core.utilities.* +import net.corda.core.transactions.ContractUpgradeWireTransaction +import net.corda.core.transactions.CoreTransaction +import net.corda.core.transactions.FullTransaction +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.core.utilities.toNonEmptySet +import net.corda.core.utilities.trace +import net.corda.node.internal.NodeServicesForResolution import net.corda.node.services.api.SchemaService import net.corda.node.services.api.VaultServiceInternal import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.statemachine.FlowStateMachineImpl -import net.corda.nodeapi.internal.persistence.* +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit +import net.corda.nodeapi.internal.persistence.contextTransactionOrNull +import net.corda.nodeapi.internal.persistence.currentDBSession +import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction import org.hibernate.Session +import org.hibernate.query.Query import rx.Observable import rx.exceptions.OnErrorNotImplementedException import rx.subjects.PublishSubject @@ -32,9 +69,11 @@ import java.security.PublicKey import java.sql.SQLException import java.time.Clock import java.time.Instant -import java.util.* +import java.util.Arrays +import java.util.UUID import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArraySet +import java.util.stream.Stream import javax.persistence.PersistenceException import javax.persistence.Tuple import javax.persistence.criteria.CriteriaBuilder @@ -54,9 +93,9 @@ import javax.persistence.criteria.Root class NodeVaultService( private val clock: Clock, private val keyManagementService: KeyManagementService, - private val servicesForResolution: ServicesForResolution, + private val servicesForResolution: NodeServicesForResolution, private val database: CordaPersistence, - private val schemaService: SchemaService, + schemaService: SchemaService, private val appClassloader: ClassLoader ) : SingletonSerializeAsToken(), VaultServiceInternal { companion object { @@ -196,7 +235,7 @@ class NodeVaultService( if (lockId != null) { lockId = null lockUpdateTime = clock.instant() - log.trace("Releasing soft lock on consumed state: $stateRef") + log.trace { "Releasing soft lock on consumed state: $stateRef" } } session.save(state) } @@ -227,7 +266,7 @@ class NodeVaultService( } // we are not inside a flow, we are most likely inside a CordaService; // we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates. - return _rawUpdatesPublisher.resilientOnError() + _rawUpdatesPublisher.resilientOnError() } override val updates: Observable> @@ -639,7 +678,23 @@ class NodeVaultService( @Throws(VaultQueryException::class) override fun _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class): Vault.Page { try { - return _queryBy(criteria, paging, sorting, contractStateType, false) + // We decrement by one if the client requests MAX_VALUE, assuming they can not notice this because they don't have enough memory + // to request MAX_VALUE states at once. + val validPaging = if (paging.pageSize == Integer.MAX_VALUE) { + paging.copy(pageSize = Integer.MAX_VALUE - 1) + } else { + checkVaultQuery(paging.pageSize >= 1) { "Page specification: invalid page size ${paging.pageSize} [minimum is 1]" } + paging + } + if (!validPaging.isDefault) { + checkVaultQuery(validPaging.pageNumber >= DEFAULT_PAGE_NUM) { + "Page specification: invalid page number ${validPaging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]" + } + } + log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $validPaging, sorting: $sorting" } + return database.transaction { + queryBy(criteria, validPaging, sorting, contractStateType) + } } catch (e: VaultQueryException) { throw e } catch (e: Exception) { @@ -647,100 +702,90 @@ class NodeVaultService( } } - @Throws(VaultQueryException::class) - private fun _queryBy(criteria: QueryCriteria, paging_: PageSpecification, sorting: Sort, contractStateType: Class, skipPagingChecks: Boolean): Vault.Page { - // We decrement by one if the client requests MAX_PAGE_SIZE, assuming they can not notice this because they don't have enough memory - // to request `MAX_PAGE_SIZE` states at once. - val paging = if (paging_.pageSize == Integer.MAX_VALUE) { - paging_.copy(pageSize = Integer.MAX_VALUE - 1) - } else { - paging_ + private fun queryBy(criteria: QueryCriteria, + paging: PageSpecification, + sorting: Sort, + contractStateType: Class): Vault.Page { + // calculate total results where a page specification has been defined + val totalStatesAvailable = if (paging.isDefault) -1 else queryTotalStateCount(criteria, contractStateType) + + val (query, stateTypes) = createQuery(criteria, contractStateType, sorting) + query.setResultWindow(paging) + + val statesMetadata: MutableList = mutableListOf() + val otherResults: MutableList = mutableListOf() + + query.resultStream(paging).use { results -> + results.forEach { result -> + val result0 = result[0] + if (result0 is VaultSchemaV1.VaultStates) { + statesMetadata.add(result0.toStateMetadata()) + } else { + log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" } + otherResults.addAll(result.toArray().asList()) + } + } } - log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" } - return database.transaction { - // calculate total results where a page specification has been defined - var totalStates = -1L - if (!skipPagingChecks && !paging.isDefault) { - val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } - val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) - val results = _queryBy(criteria.and(countCriteria), PageSpecification(), Sort(emptyList()), contractStateType, true) // only skip pagination checks for total results count query - totalStates = results.otherResults.last() as Long - } - val session = getSession() + val states: List> = servicesForResolution.loadStates( + statesMetadata.mapTo(LinkedHashSet()) { it.ref }, + ArrayList() + ) - val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) - val queryRootVaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) - - // TODO: revisit (use single instance of parser for all queries) - val criteriaParser = HibernateQueryCriteriaParser(contractStateType, contractStateTypeMappings, criteriaBuilder, criteriaQuery, queryRootVaultStates) - - // parse criteria and build where predicates - criteriaParser.parse(criteria, sorting) - - // prepare query for execution - val query = session.createQuery(criteriaQuery) - - // pagination checks - if (!skipPagingChecks && !paging.isDefault) { - // pagination - if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]") - if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [minimum is 1]") - if (paging.pageSize > MAX_PAGE_SIZE) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [maximum is $MAX_PAGE_SIZE]") - } - - // For both SQLServer and PostgresSQL, firstResult must be >= 0. So we set a floor at 0. - // TODO: This is a catch-all solution. But why is the default pageNumber set to be -1 in the first place? - // Even if we set the default pageNumber to be 1 instead, that may not cover the non-default cases. - // So the floor may be necessary anyway. - query.firstResult = maxOf(0, (paging.pageNumber - 1) * paging.pageSize) - val pageSize = paging.pageSize + 1 - query.maxResults = if (pageSize > 0) pageSize else Integer.MAX_VALUE // detection too many results, protected against overflow - - // execution - val results = query.resultList + return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults) + } + private fun Query.resultStream(paging: PageSpecification): Stream { + return if (paging.isDefault) { + val allResults = resultList // final pagination check (fail-fast on too many results when no pagination specified) - if (!skipPagingChecks && paging.isDefault && results.size > DEFAULT_PAGE_SIZE) { - throw VaultQueryException("There are ${results.size} results, which exceeds the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. In order to retrieve these results, provide a `PageSpecification(pageNumber, pageSize)` to the method invoked.") + checkVaultQuery(allResults.size != paging.pageSize + 1) { + "There are more results than the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. " + + "In order to retrieve these results, provide a PageSpecification to the method invoked." } - val statesAndRefs: MutableList> = mutableListOf() - val statesMeta: MutableList = mutableListOf() - val otherResults: MutableList = mutableListOf() - val stateRefs = mutableSetOf() - - results.asSequence() - .forEachIndexed { index, result -> - if (result[0] is VaultSchemaV1.VaultStates) { - if (!paging.isDefault && index == paging.pageSize) // skip last result if paged - return@forEachIndexed - val vaultState = result[0] as VaultSchemaV1.VaultStates - val stateRef = StateRef(SecureHash.create(vaultState.stateRef!!.txId), vaultState.stateRef!!.index) - stateRefs.add(stateRef) - statesMeta.add(Vault.StateMetadata(stateRef, - vaultState.contractStateClassName, - vaultState.recordedTime, - vaultState.consumedTime, - vaultState.stateStatus, - vaultState.notary, - vaultState.lockId, - vaultState.lockUpdateTime, - vaultState.relevancyStatus, - constraintInfo(vaultState.constraintType, vaultState.constraintData) - )) - } else { - // TODO: improve typing of returned other results - log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" } - otherResults.addAll(result.toArray().asList()) - } - } - if (stateRefs.isNotEmpty()) - statesAndRefs.addAll(uncheckedCast(servicesForResolution.loadStates(stateRefs))) - - Vault.Page(states = statesAndRefs, statesMetadata = statesMeta, stateTypes = criteriaParser.stateTypes, totalStatesAvailable = totalStates, otherResults = otherResults) + allResults.stream() + } else { + stream() } } + private fun Query<*>.setResultWindow(paging: PageSpecification) { + if (paging.isDefault) { + // For both SQLServer and PostgresSQL, firstResult must be >= 0. + firstResult = 0 + // Peek ahead and see if there are more results in case pagination should be done + maxResults = paging.pageSize + 1 + } else { + firstResult = (paging.pageNumber - 1) * paging.pageSize + maxResults = paging.pageSize + } + } + + private fun queryTotalStateCount(baseCriteria: QueryCriteria, contractStateType: Class): Long { + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val criteria = baseCriteria.and(countCriteria) + val (query) = createQuery(criteria, contractStateType, null) + val results = query.resultList + return results.last().toArray().last() as Long + } + + private fun createQuery(criteria: QueryCriteria, + contractStateType: Class, + sorting: Sort?): Pair, Vault.StateStatus> { + val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) + val criteriaParser = HibernateQueryCriteriaParser( + contractStateType, + contractStateTypeMappings, + criteriaBuilder, + criteriaQuery, + criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) + ) + criteriaParser.parse(criteria, sorting) + val query = getSession().createQuery(criteriaQuery) + return Pair(query, criteriaParser.stateTypes) + } + /** * Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates. * @@ -775,6 +820,12 @@ class NodeVaultService( } } + private inline fun checkVaultQuery(value: Boolean, lazyMessage: () -> Any) { + if (!value) { + throw VaultQueryException(lazyMessage().toString()) + } + } + private fun filterContractStates(update: Vault.Update, contractStateType: Class) = update.copy(consumed = filterByContractState(contractStateType, update.consumed), produced = filterByContractState(contractStateType, update.produced)) @@ -802,6 +853,7 @@ class NodeVaultService( } private fun getSession() = database.currentOrNew().session + /** * Derive list from existing vault states and then incrementally update using vault observables */ diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt index 06844d40d0..09c71fe1f7 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt @@ -2,7 +2,9 @@ package net.corda.node.services.vault import net.corda.core.contracts.ContractState import net.corda.core.contracts.MAX_ISSUER_REF_SIZE +import net.corda.core.contracts.StateRef import net.corda.core.contracts.UniqueIdentifier +import net.corda.core.crypto.SecureHash import net.corda.core.crypto.toStringShort import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party @@ -192,3 +194,19 @@ object VaultSchemaV1 : MappedSchema( ) : IndirectStatePersistable } +fun PersistentStateRef.toStateRef(): StateRef = StateRef(SecureHash.create(txId), index) + +fun VaultSchemaV1.VaultStates.toStateMetadata(): Vault.StateMetadata { + return Vault.StateMetadata( + stateRef!!.toStateRef(), + contractStateClassName, + recordedTime, + consumedTime, + stateStatus, + notary, + lockId, + lockUpdateTime, + relevancyStatus, + Vault.ConstraintInfo.constraintInfo(constraintType, constraintData) + ) +} diff --git a/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt b/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt index a570ccd7b5..76094c2a1d 100644 --- a/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt +++ b/node/src/main/kotlin/net/corda/notary/experimental/bftsmart/BFTSmartNotaryService.kt @@ -21,6 +21,7 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.node.services.vault.toStateRef import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import java.security.PublicKey @@ -41,6 +42,8 @@ class BFTSmartNotaryService( ) : NotaryService() { companion object { private val log = contextLogger() + + @Suppress("unused") // Used by NotaryLoader via reflection @JvmStatic val serializationFilter get() = { clazz: Class<*> -> @@ -147,12 +150,7 @@ class BFTSmartNotaryService( toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, fromPersistentEntity = { //TODO null check will become obsolete after making DB/JPA columns not nullable - val txId = it.id.txId - val index = it.id.index - Pair( - StateRef(txhash = SecureHash.create(txId), index = index), - SecureHash.create(it.consumingTxHash) - ) + Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash)) }, toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> CommittedState( diff --git a/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt b/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt index d38a3f35b7..b678478da6 100644 --- a/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/notary/jpa/JPAUniquenessProvider.kt @@ -24,6 +24,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.serialize import net.corda.core.utilities.contextLogger import net.corda.core.utilities.debug +import net.corda.node.services.vault.toStateRef import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.notary.common.InternalResult @@ -142,10 +143,6 @@ class JPAUniquenessProvider( fun encodeStateRef(s: StateRef): PersistentStateRef { return PersistentStateRef(s.txhash.toString(), s.index) } - - fun decodeStateRef(s: PersistentStateRef): StateRef { - return StateRef(txhash = SecureHash.create(s.txId), index = s.index) - } } /** @@ -215,15 +212,15 @@ class JPAUniquenessProvider( committedStates.addAll(existing) } - return committedStates.map { - val stateRef = StateRef(txhash = SecureHash.create(it.id.txId), index = it.id.index) + return committedStates.associate { + val stateRef = it.id.toStateRef() val consumingTxId = SecureHash.create(it.consumingTxHash) if (stateRef in references) { stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE) } else { stateRef to StateConsumptionDetails(consumingTxId.reHash()) } - }.toMap() + } } private fun withRetry(block: () -> T): T { diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt index 1efb349ca0..30cdbe7f59 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/HibernateConfigurationTest.kt @@ -28,12 +28,14 @@ import net.corda.finance.schemas.CashSchemaV1 import net.corda.finance.test.SampleCashSchemaV1 import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV3 +import net.corda.node.internal.NodeServicesForResolution import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.schema.ContractStateAndRef import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.VaultSchemaV1 +import net.corda.node.services.vault.toStateRef import net.corda.node.testing.DummyFungibleContract import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig @@ -48,7 +50,6 @@ import net.corda.testing.internal.vault.VaultFiller import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import org.assertj.core.api.Assertions -import org.assertj.core.api.Assertions.`in` import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.hibernate.SessionFactory @@ -122,7 +123,14 @@ class HibernateConfigurationTest { services = object : MockServices(cordappPackages, BOB_NAME, mock().also { doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME }) }, generateKeyPair(), dummyNotary.keyPair) { - override val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, servicesForResolution, database, schemaService, cordappClassloader).apply { start() } + override val vaultService = NodeVaultService( + Clock.systemUTC(), + keyManagementService, + servicesForResolution as NodeServicesForResolution, + database, + schemaService, + cordappClassloader + ).apply { start() } override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable) { for (stx in txs) { (validatedTransactions as WritableTransactionStorage).addTransaction(stx) @@ -183,7 +191,7 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList val coins = queryResults.map { - services.loadState(toStateRef(it.stateRef!!)).data + services.loadState(it.stateRef!!.toStateRef()).data }.sumCash() assertThat(coins.toDecimal() >= BigDecimal("50.00")) } @@ -739,7 +747,7 @@ class HibernateConfigurationTest { val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State + val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}") } @@ -823,7 +831,7 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State + val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}") } @@ -961,10 +969,6 @@ class HibernateConfigurationTest { } } - private fun toStateRef(pStateRef: PersistentStateRef): StateRef { - return StateRef(SecureHash.create(pStateRef.txId), pStateRef.index) - } - @Test(timeout=300_000) fun `schema change`() { fun createNewDB(schemas: Set, initialiseSchema: Boolean = true): CordaPersistence { diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 1b139ab022..b06518667c 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -1674,7 +1674,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { // pagination: last page @Test(timeout=300_000) - fun `all states with paging specification - last`() { + fun `all states with paging specification - last`() { database.transaction { vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER) // Last page implies we need to perform a row count for the Query first, @@ -1723,7 +1723,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties { @Test(timeout=300_000) fun `pagination not specified but more than default results available`() { expectedEx.expect(VaultQueryException::class.java) - expectedEx.expectMessage("provide a `PageSpecification(pageNumber, pageSize)`") + expectedEx.expectMessage("provide a PageSpecification") database.transaction { vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER) diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt index 7e771e9904..ac621c9bff 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt @@ -10,7 +10,6 @@ import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.AbstractParty import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.uncheckedCast -import net.corda.core.node.ServicesForResolution import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.queryBy import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition @@ -29,6 +28,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.testing.core.singleIdentity import net.corda.testing.flows.registerCoreFlowFactory import net.corda.coretesting.internal.rigorousMock +import net.corda.node.internal.NodeServicesForResolution import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.enclosedCordapp import net.corda.testing.node.internal.startFlow @@ -86,7 +86,7 @@ class VaultSoftLockManagerTest { private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args -> object : InternalMockNetwork.MockNode(args) { override fun makeVaultService(keyManagementService: KeyManagementService, - services: ServicesForResolution, + services: NodeServicesForResolution, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { val node = this diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt index efd5813736..6dcb0db299 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -26,6 +26,7 @@ import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.node.VersionInfo import net.corda.node.internal.ServicesForResolutionImpl +import net.corda.node.internal.NodeServicesForResolution import net.corda.node.internal.cordapp.JarScanningCordappLoader import net.corda.node.services.api.* import net.corda.node.services.diagnostics.NodeDiagnosticsService @@ -460,7 +461,14 @@ open class MockServices private constructor( get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions) internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { - return NodeVaultService(clock, keyManagementService, servicesForResolution, database, schemaService, cordappLoader.appClassLoader).apply { start() } + return NodeVaultService( + clock, + keyManagementService, + servicesForResolution as NodeServicesForResolution, + database, + schemaService, + cordappLoader.appClassLoader + ).apply { start() } } // This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API From dfcc7019dcb0adecd89a6017e0905b1f0c188e61 Mon Sep 17 00:00:00 2001 From: Connel McGovern <100574906+mcgovc@users.noreply.github.com> Date: Fri, 2 Jun 2023 17:53:24 +0100 Subject: [PATCH 07/12] ES-562: Correct modules to scan for C4 OS Snyk scan nightly --- .ci/dev/nightly-regression/JenkinsfileSnykScan | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/dev/nightly-regression/JenkinsfileSnykScan b/.ci/dev/nightly-regression/JenkinsfileSnykScan index 564bb516a9..6c0f81d698 100644 --- a/.ci/dev/nightly-regression/JenkinsfileSnykScan +++ b/.ci/dev/nightly-regression/JenkinsfileSnykScan @@ -3,5 +3,5 @@ cordaSnykScanPipeline ( snykTokenId: 'c4-os-snyk-api-token-secret', // specify the Gradle submodules to scan and monitor on snyk Server - modulesToScan: ['node', 'capsule', 'bridge', 'bridgecapsule'] + modulesToScan: ['node', 'capsule'] ) From ac9f3c150fd9983bfbb8fe8097e233f1c801f7a1 Mon Sep 17 00:00:00 2001 From: Connel McGovern Date: Fri, 2 Jun 2023 17:53:24 +0100 Subject: [PATCH 08/12] Include 'ES' jira code in PR title check --- .github/workflows/check-pr-title.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml index a27b6c02e4..f99824a302 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check-pr-title.yml @@ -9,6 +9,6 @@ jobs: steps: - uses: morrisoncole/pr-lint-action@v1.4.1 with: - title-regex: '^((CORDA|AG|EG|ENT|INFRA|NAAS)-\d+|NOTICK)(.*)' + title-regex: '^((CORDA|AG|EG|ENT|INFRA|ES)-\d+|NOTICK)(.*)' on-failed-regex-comment: "PR title failed to match regex -> `%regex%`" repo-token: "${{ secrets.GITHUB_TOKEN }}" From 596b51c2fd58b90a8a6e65d737a384edb286ea00 Mon Sep 17 00:00:00 2001 From: Connel McGovern Date: Tue, 6 Jun 2023 16:43:28 +0100 Subject: [PATCH 09/12] Removing bridge/bridgecapsule from main release branch CI pipeline --- .ci/dev/regression/Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/dev/regression/Jenkinsfile b/.ci/dev/regression/Jenkinsfile index 1ce7245e68..1c3b7ed64c 100644 --- a/.ci/dev/regression/Jenkinsfile +++ b/.ci/dev/regression/Jenkinsfile @@ -90,7 +90,7 @@ pipeline { steps { script { // Invoke Snyk for each Gradle sub project we wish to scan - def modulesToScan = ['node', 'capsule', 'bridge', 'bridgecapsule'] + def modulesToScan = ['node', 'capsule'] modulesToScan.each { module -> snykSecurityScan("${env.SNYK_API_KEY}", "--sub-project=$module --configuration-matching='^runtimeClasspath\$' --prune-repeated-subdependencies --debug --target-reference='${env.BRANCH_NAME}' --project-tags=Branch='${env.BRANCH_NAME.replaceAll("[^0-9|a-z|A-Z]+","_")}'") } From 2246c94fd5bc3c4a909f207ee0a9217823affe89 Mon Sep 17 00:00:00 2001 From: Connel McGovern <100574906+mcgovc@users.noreply.github.com> Date: Tue, 6 Jun 2023 17:08:19 +0100 Subject: [PATCH 10/12] ES-562: Updating .snyk YAML indentation & updating modules to scan on Snyk nightly (#7385) * NOTICK: Correct Yaml whitespace * Update JenkinsfileSnykScan Snyk modules * Correcting YAML indentation * NOTICK: Update reges to match ES Jira tickets * Removing bridge/bridgecapsule from main release branch CI pipeline --- .ci/dev/nightly-regression/JenkinsfileSnykScan | 2 +- .ci/dev/regression/Jenkinsfile | 2 +- .github/workflows/check-pr-title.yml | 2 +- .snyk | 14 +++++++------- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.ci/dev/nightly-regression/JenkinsfileSnykScan b/.ci/dev/nightly-regression/JenkinsfileSnykScan index 564bb516a9..6c0f81d698 100644 --- a/.ci/dev/nightly-regression/JenkinsfileSnykScan +++ b/.ci/dev/nightly-regression/JenkinsfileSnykScan @@ -3,5 +3,5 @@ cordaSnykScanPipeline ( snykTokenId: 'c4-os-snyk-api-token-secret', // specify the Gradle submodules to scan and monitor on snyk Server - modulesToScan: ['node', 'capsule', 'bridge', 'bridgecapsule'] + modulesToScan: ['node', 'capsule'] ) diff --git a/.ci/dev/regression/Jenkinsfile b/.ci/dev/regression/Jenkinsfile index 02dc1a403d..4bab8e416c 100644 --- a/.ci/dev/regression/Jenkinsfile +++ b/.ci/dev/regression/Jenkinsfile @@ -92,7 +92,7 @@ pipeline { steps { script { // Invoke Snyk for each Gradle sub project we wish to scan - def modulesToScan = ['node', 'capsule', 'bridge', 'bridgecapsule'] + def modulesToScan = ['node', 'capsule'] modulesToScan.each { module -> snykSecurityScan("${env.SNYK_API_KEY}", "--sub-project=$module --configuration-matching='^runtimeClasspath\$' --prune-repeated-subdependencies --debug --target-reference='${env.BRANCH_NAME}' --project-tags=Branch='${env.BRANCH_NAME.replaceAll("[^0-9|a-z|A-Z]+","_")}'") } diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml index a27b6c02e4..6d45a2bd31 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check-pr-title.yml @@ -9,6 +9,6 @@ jobs: steps: - uses: morrisoncole/pr-lint-action@v1.4.1 with: - title-regex: '^((CORDA|AG|EG|ENT|INFRA|NAAS)-\d+|NOTICK)(.*)' + title-regex: '^((CORDA|AG|EG|ENT|INFRA|NAAS|ES)-\d+|NOTICK)(.*)' on-failed-regex-comment: "PR title failed to match regex -> `%regex%`" repo-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.snyk b/.snyk index 2b9605267a..93a9db4572 100644 --- a/.snyk +++ b/.snyk @@ -131,7 +131,7 @@ ignore: this vulnerability. expires: 2023-09-01T11:32:38.120Z created: 2022-09-21T11:32:38.125Z -SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: + SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: - '*': reason: >- Corda does not set the non-default UNWRAP_SINGLE_VALUE_ARRAYS required @@ -145,7 +145,7 @@ SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: nesting are potentially susceptible. expires: 2023-09-01T12:04:40.180Z created: 2023-02-09T12:04:40.209Z - SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038426: + SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038426: - '*': reason: >- Corda does not set the non-default UNWRAP_SINGLE_VALUE_ARRAYS required @@ -159,7 +159,7 @@ SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: nesting are potentially susceptible. expires: 2023-09-01T12:05:03.931Z created: 2023-02-09T12:05:03.962Z - SNYK-JAVA-ORGYAML-2806360: + SNYK-JAVA-ORGYAML-2806360: - '*': reason: >- Snakeyaml is being used by Jackson and liquidbase. Corda does not use @@ -172,7 +172,7 @@ SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: not exposed to this DOS vulnerability. expires: 2023-09-01T13:40:55.262Z created: 2022-09-21T13:40:55.279Z - SNYK-JAVA-ORGYAML-3016891: + SNYK-JAVA-ORGYAML-3016891: - '*': reason: >- Snakeyaml is being used by Jackson and liquidbase. Corda does not use @@ -186,7 +186,7 @@ SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: vulnerability. expires: 2023-09-01T16:37:28.911Z created: 2023-02-06T16:37:28.933Z - SNYK-JAVA-ORGYAML-3016888: + SNYK-JAVA-ORGYAML-3016888: - '*': reason: >- Snakeyaml is being used by Jackson and liquidbase. Corda does not use @@ -200,7 +200,7 @@ SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: vulnerability. expires: 2023-09-01T13:39:49.450Z created: 2022-09-21T13:39:49.470Z - SNYK-JAVA-ORGYAML-3016889: + SNYK-JAVA-ORGYAML-3016889: - '*': reason: >- Snakeyaml is being used by Jackson and liquidbase. Corda does not use @@ -214,7 +214,7 @@ SNYK-JAVA-COMFASTERXMLJACKSONCORE-3038424: vulnerability. expires: 2023-09-01T16:35:13.840Z created: 2023-02-06T16:35:13.875Z - SNYK-JAVA-ORGYAML-3113851: + SNYK-JAVA-ORGYAML-3113851: - '*': reason: >- Snakeyaml is being used by Jackson and liquidbase. Corda does not use From 5b3180bf7bb66d50ae68dc51bc8e2e5f651d27d0 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Fri, 9 Jun 2023 11:17:26 +0100 Subject: [PATCH 11/12] ENT-10016: Give all node threads descriptive names --- .../net/corda/client/rpc/CordaRPCClient.kt | 3 +- .../rpc/internal/ReconnectingCordaRPCOps.kt | 4 +- .../internal/ArtemisMessagingClient.kt | 4 +- .../nodeapi/internal/ArtemisTcpTransport.kt | 11 ++- .../corda/nodeapi/internal/ArtemisUtils.kt | 32 +++++++ ...ctory.kt => CordaNettyConnectorFactory.kt} | 22 ++++- .../internal/bridging/AMQPBridgeManager.kt | 96 ++++++++++--------- .../bridging/BridgeControlListener.kt | 4 +- .../bridging/LoopbackBridgeManager.kt | 4 +- .../protonwrapper/netty/AMQPClient.kt | 63 ++++++++---- .../CertificateRevocationListNodeTests.kt | 2 +- .../net/corda/node/amqp/ProtonWrapperTests.kt | 2 +- .../net/corda/node/internal/AbstractNode.kt | 5 +- .../kotlin/net/corda/node/internal/Node.kt | 8 +- .../services/events/NodeSchedulerService.kt | 3 +- .../messaging/ArtemisMessagingServer.kt | 2 +- .../messaging/NodeNettyAcceptorFactory.kt | 22 ++++- .../services/network/NetworkMapUpdater.kt | 9 +- .../rpc/InternalRPCMessagingClient.kt | 3 +- .../services/rpc/RpcBrokerConfiguration.kt | 4 +- .../node/services/statemachine/FlowMonitor.kt | 7 +- 21 files changed, 203 insertions(+), 107 deletions(-) rename node-api/src/main/kotlin/net/corda/nodeapi/internal/{NodeNettyConnectorFactory.kt => CordaNettyConnectorFactory.kt} (70%) diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index d90befe6ae..d008a351dc 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -1,5 +1,6 @@ package net.corda.client.rpc +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.ReconnectingCordaRPCOps import net.corda.client.rpc.internal.SerializationEnvironmentHelper @@ -52,7 +53,7 @@ class CordaRPCConnection private constructor( sslConfiguration: ClientRpcSslOptions? = null, classLoader: ClassLoader? = null ): CordaRPCConnection { - val observersPool: ExecutorService = Executors.newCachedThreadPool() + val observersPool: ExecutorService = Executors.newCachedThreadPool(DefaultThreadFactory("RPCObserver")) return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps( addresses, username, diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt index 71964d961e..005ac70fd3 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/ReconnectingCordaRPCOps.kt @@ -1,5 +1,6 @@ package net.corda.client.rpc.internal +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.client.rpc.ConnectionFailureException import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClientConfiguration @@ -99,7 +100,8 @@ class ReconnectingCordaRPCOps private constructor( ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps } } - private val retryFlowsPool = Executors.newScheduledThreadPool(1) + private val retryFlowsPool = Executors.newScheduledThreadPool(1, DefaultThreadFactory("FlowRetry")) + /** * This function runs a flow and retries until it completes successfully. * diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt index 721b856fdd..74d580a827 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt @@ -42,8 +42,8 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration, override fun start(): Started = synchronized(this) { check(started == null) { "start can't be called twice" } val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace) - val backupTransports = backupServerAddressPool.map { - p2pConnectorTcpTransport(it, config, threadPoolName = threadPoolName, trace = trace) + val backupTransports = backupServerAddressPool.mapIndexed { index, address -> + p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace) } log.info("Connecting to message broker: $serverAddress") diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt index 3b5adaf934..84d63df5e2 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisTcpTransport.kt @@ -122,6 +122,7 @@ class ArtemisTcpTransport { fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: BrokerRpcSslOptions?, enableSSL: Boolean = true, + threadPoolName: String = "RPCServer", trace: Boolean = false, remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() @@ -129,7 +130,7 @@ class ArtemisTcpTransport { config.keyStorePath.requireOnDefaultFileSystem() options.putAll(config.toTransportOptions()) } - return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, "RPCServer", trace, remotingThreads) + return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads) } fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, @@ -147,14 +148,16 @@ class ArtemisTcpTransport { fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, + threadPoolName: String = "Internal-RPCClient", trace: Boolean = false): TransportConfiguration { val options = mutableMapOf() config.addToTransportOptions(options) - return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace, null) + return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null) } fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, + threadPoolName: String = "Internal-RPCServer", trace: Boolean = false, remotingThreads: Int? = null): TransportConfiguration { val options = mutableMapOf() @@ -165,7 +168,7 @@ class ArtemisTcpTransport { options, trustManagerFactory(requireNotNull(config.trustStore).get()), true, - "Internal-RPCServer", + threadPoolName, trace, remotingThreads ) @@ -209,7 +212,7 @@ class ArtemisTcpTransport { trace: Boolean, remotingThreads: Int?): TransportConfiguration { return createTransport( - NodeNettyConnectorFactory::class.java.name, + CordaNettyConnectorFactory::class.java.name, hostAndPort, protocols, options, diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt index 23bb9d1428..a3c2109d32 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisUtils.kt @@ -1,8 +1,14 @@ @file:JvmName("ArtemisUtils") package net.corda.nodeapi.internal +import net.corda.core.internal.declaredField +import org.apache.activemq.artemis.utils.actors.ProcessorBase import java.nio.file.FileSystems import java.nio.file.Path +import java.util.concurrent.Executor +import java.util.concurrent.ThreadFactory +import java.util.concurrent.ThreadPoolExecutor +import java.util.concurrent.atomic.AtomicInteger /** * Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use. @@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) { require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" } } +val Executor.rootExecutor: Executor get() { + var executor: Executor = this + while (executor is ProcessorBase<*>) { + executor = executor.declaredField("delegate").value + } + return executor +} + +fun Executor.setThreadPoolName(threadPoolName: String) { + (rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) } +} + +private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory { + companion object { + private val poolId = AtomicInteger(0) + } + + private val prefix = "$poolName-${poolId.incrementAndGet()}-" + private val nextId = AtomicInteger(0) + + override fun newThread(r: Runnable): Thread { + val thread = delegate.newThread(r) + thread.name = "$prefix${nextId.incrementAndGet()}" + return thread + } +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt similarity index 70% rename from node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt index 47e046566e..a9bdc519a9 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/NodeNettyConnectorFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/CordaNettyConnectorFactory.kt @@ -14,15 +14,16 @@ import org.apache.activemq.artemis.utils.ConfigurationHelper import java.util.concurrent.Executor import java.util.concurrent.ScheduledExecutorService -class NodeNettyConnectorFactory : ConnectorFactory { +class CordaNettyConnectorFactory : ConnectorFactory { override fun createConnector(configuration: MutableMap?, handler: BufferHandler?, listener: ClientConnectionLifeCycleListener?, - closeExecutor: Executor?, - threadPool: Executor?, - scheduledThreadPool: ScheduledExecutorService?, + closeExecutor: Executor, + threadPool: Executor, + scheduledThreadPool: ScheduledExecutorService, protocolManager: ClientProtocolManager?): Connector { val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration) + setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName) val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) return NettyConnector( configuration, @@ -31,7 +32,7 @@ class NodeNettyConnectorFactory : ConnectorFactory { closeExecutor, threadPool, scheduledThreadPool, - MyClientProtocolManager(threadPoolName, trace) + MyClientProtocolManager("$threadPoolName-netty", trace) ) } @@ -39,6 +40,17 @@ class NodeNettyConnectorFactory : ConnectorFactory { override fun getDefaults(): Map = NettyConnector.DEFAULT_CONFIG + private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) { + threadPool.setThreadPoolName("$name-artemis") + // Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool + // and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names. + if (threadPool.rootExecutor !== closeExecutor.rootExecutor) { + closeExecutor.setThreadPoolName("$name-artemis-closer") + } + // The scheduler is separate + scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler") + } + private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() { override fun addChannelHandlers(pipeline: ChannelPipeline) { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt index deb6ef999a..93ab5616de 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/AMQPBridgeManager.kt @@ -22,6 +22,7 @@ import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig +import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE @@ -31,6 +32,7 @@ import org.apache.activemq.artemis.api.core.client.ClientSession import org.slf4j.MDC import rx.Subscription import java.time.Duration +import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledFuture @@ -53,7 +55,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, private val bridgeMetricsService: BridgeMetricsService? = null, trace: Boolean, sslHandshakeTimeout: Duration?, @@ -78,9 +80,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout) private var sharedEventLoopGroup: EventLoopGroup? = null + private var sslDelegatedTaskExecutor: ExecutorService? = null private var artemis: ArtemisSessionProvider? = null companion object { + private val log = contextLogger() private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads" @@ -97,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore, * however Artemis and the remote Corda instanced will deduplicate these messages. */ @Suppress("TooManyFunctions") - private class AMQPBridge(val sourceX500Name: String, - val queueName: String, - val targets: List, - val allowedRemoteLegalNames: Set, - private val amqpConfig: AMQPConfiguration, - sharedEventGroup: EventLoopGroup, - private val artemis: ArtemisSessionProvider, - private val bridgeMetricsService: BridgeMetricsService?, - private val bridgeConnectionTTLSeconds: Int) { - companion object { - private val log = contextLogger() - } + private inner class AMQPBridge(val sourceX500Name: String, + val queueName: String, + val targets: List, + val allowedRemoteLegalNames: Set, + private val amqpConfig: AMQPConfiguration) { private fun withMDC(block: () -> Unit) { val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap() @@ -134,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } - val amqpClient = AMQPClient(targets, allowedRemoteLegalNames, amqpConfig, sharedThreadPool = sharedEventGroup) + val amqpClient = AMQPClient( + targets, + allowedRemoteLegalNames, + amqpConfig, + AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!) + ) private var session: ClientSession? = null private var consumer: ClientConsumer? = null private var connectedSubscription: Subscription? = null @Volatile private var messagesReceived: Boolean = false - private val eventLoop: EventLoop = sharedEventGroup.next() + private val eventLoop: EventLoop = sharedEventLoopGroup!!.next() private var artemisState: ArtemisState = ArtemisState.STOPPED set(value) { logDebugWithMDC { "State change $field to $value" } @@ -152,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore, private var scheduledExecutorService: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build()) - @Suppress("ClassNaming") - private sealed class ArtemisState { - object STARTING : ArtemisState() - data class STARTED(override val pending: ScheduledFuture) : ArtemisState() - - object CHECKING : ArtemisState() - object RESTARTED : ArtemisState() - object RECEIVING : ArtemisState() - - object AMQP_STOPPED : ArtemisState() - object AMQP_STARTING : ArtemisState() - object AMQP_STARTED : ArtemisState() - object AMQP_RESTARTED : ArtemisState() - - object STOPPING : ArtemisState() - object STOPPED : ArtemisState() - data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture) : ArtemisState() - - open val pending: ScheduledFuture? = null - - override fun toString(): String = javaClass.simpleName - } - private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) { val runnable = { - synchronized(artemis) { + synchronized(artemis!!) { try { val precedingState = artemisState artemisState.pending?.cancel(false) @@ -253,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } } artemis(ArtemisState.STARTING) { - val startedArtemis = artemis.started + val startedArtemis = artemis!!.started if (startedArtemis == null) { logInfoWithMDC("Bridge Connected but Artemis is disconnected") ArtemisState.STOPPED @@ -457,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } } + @Suppress("ClassNaming") + private sealed class ArtemisState { + object STARTING : ArtemisState() + data class STARTED(override val pending: ScheduledFuture) : ArtemisState() + + object CHECKING : ArtemisState() + object RESTARTED : ArtemisState() + object RECEIVING : ArtemisState() + + object AMQP_STOPPED : ArtemisState() + object AMQP_STARTING : ArtemisState() + object AMQP_STARTED : ArtemisState() + object AMQP_RESTARTED : ArtemisState() + + object STOPPING : ArtemisState() + object STOPPED : ArtemisState() + data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture) : ArtemisState() + + open val pending: ScheduledFuture? = null + + override fun toString(): String = javaClass.simpleName + } + override fun deployBridge(sourceX500Name: String, queueName: String, targets: List, legalNames: Set) { lock.withLock { val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() } @@ -467,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore, } val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) } - val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!, - bridgeMetricsService, bridgeConnectionTTLSeconds) + val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig) bridges += newBridge bridgeMetricsService?.bridgeCreated(targets, legalNames) newBridge @@ -497,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore, // queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method. val bridges = queueNamesToBridgesMap[queueName]?.toList() destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) - bridges?.map { + bridges?.associate { it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false) - }?.toMap() ?: emptyMap() + } ?: emptyMap() } } override fun start() { - sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("AMQPBridge", Thread.MAX_PRIORITY)) - val artemis = artemisMessageClientFactory() + sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY)) + sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge") + val artemis = artemisMessageClientFactory("ArtemisBridge") this.artemis = artemis artemis.start() } @@ -522,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore, sharedEventLoopGroup = null queueNamesToBridgesMap.clear() artemis?.stop() + sslDelegatedTaskExecutor?.shutdown() + sslDelegatedTaskExecutor = null } } } \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt index 0a402d8854..708588cb63 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/BridgeControlListener.kt @@ -34,7 +34,7 @@ class BridgeControlListener(private val keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, bridgeMetricsService: BridgeMetricsService? = null, trace: Boolean = false, sslHandshakeTimeout: Duration? = null, @@ -79,7 +79,7 @@ class BridgeControlListener(private val keyStore: CertificateStore, bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId" bridgeManager.start() - val artemis = artemisMessageClientFactory() + val artemis = artemisMessageClientFactory("BridgeControl") this.artemis = artemis artemis.start() val artemisClient = artemis.started!! diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt index e9ac1ca522..2dd9f8bff0 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/bridging/LoopbackBridgeManager.kt @@ -37,7 +37,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore, maxMessageSize: Int, revocationConfig: RevocationConfig, enableSNI: Boolean, - private val artemisMessageClientFactory: () -> ArtemisSessionProvider, + private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider, private val bridgeMetricsService: BridgeMetricsService? = null, private val isLocalInbox: (String) -> Boolean, trace: Boolean, @@ -204,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore, override fun start() { super.start() - val artemis = artemisMessageClientFactory() + val artemis = artemisMessageClientFactory("LoopbackBridge") this.artemis = artemis artemis.start() } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt index 3c18830147..c502817029 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt @@ -32,7 +32,9 @@ import rx.Observable import rx.subjects.PublishSubject import java.lang.Long.min import java.net.InetSocketAddress +import java.util.concurrent.Executor import java.util.concurrent.ExecutorService +import java.util.concurrent.ThreadPoolExecutor import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.withLock @@ -61,8 +63,7 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA class AMQPClient(private val targets: List, val allowedRemoteLegalNames: Set, private val configuration: AMQPConfiguration, - private val sharedThreadPool: EventLoopGroup? = null, - private val threadPoolName: String = "AMQPClient", + private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"), private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable { companion object { init { @@ -82,7 +83,6 @@ class AMQPClient(private val targets: List, private val lock = ReentrantLock() @Volatile private var started: Boolean = false - private var workerGroup: EventLoopGroup? = null @Volatile private var clientChannel: Channel? = null // Offset into the list of targets, so that we can implement round-robin reconnect logic. @@ -94,7 +94,6 @@ class AMQPClient(private val targets: List, private var amqpActive = false @Volatile private var amqpChannelHandler: ChannelHandler? = null - private var sslDelegatedTaskExecutor: ExecutorService? = null val localAddressString: String get() = clientChannel?.localAddress()?.toString() ?: "" @@ -123,7 +122,7 @@ class AMQPClient(private val targets: List, log.info("Failed to connect to $currentTarget", future.cause()) if (started) { - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -142,7 +141,7 @@ class AMQPClient(private val targets: List, clientChannel = null if (started && !amqpActive) { log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" } - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -198,7 +197,6 @@ class AMQPClient(private val targets: List, val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val target = parent.currentTarget - val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor) val handler = if (parent.configuration.useOpenSsl) { createClientOpenSslHandler( target, @@ -206,7 +204,7 @@ class AMQPClient(private val targets: List, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc(), - delegatedTaskExecutor + parent.nettyThreading.sslDelegatedTaskExecutor ) } else { createClientSslHandler( @@ -214,7 +212,7 @@ class AMQPClient(private val targets: List, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, - delegatedTaskExecutor + parent.nettyThreading.sslDelegatedTaskExecutor ) } handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis() @@ -256,7 +254,7 @@ class AMQPClient(private val targets: List, if (started && amqpActive) { log.debug { "Scheduling restart of $currentTarget (AMQP active)" } - workerGroup?.schedule({ + nettyThreading.eventLoopGroup.schedule({ nextTarget() restart() }, retryInterval, TimeUnit.MILLISECONDS) @@ -273,8 +271,7 @@ class AMQPClient(private val targets: List, return } log.info("Connect to: $currentTarget") - sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) - workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) + (nettyThreading as? NettyThreading.NonShared)?.start() started = true restart() } @@ -286,7 +283,7 @@ class AMQPClient(private val targets: List, } val bootstrap = Bootstrap() // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux - bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) + bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) // Delegate DNS Resolution to the proxy side, if we are using proxy. if (configuration.proxyConfig != null) { bootstrap.resolver(NoopAddressResolverGroup.INSTANCE) @@ -300,16 +297,12 @@ class AMQPClient(private val targets: List, lock.withLock { log.info("Stopping connection to: $currentTarget, Local address: $localAddressString") started = false - if (sharedThreadPool == null) { - workerGroup?.shutdownGracefully() - workerGroup?.terminationFuture()?.sync() + if (nettyThreading is NettyThreading.NonShared) { + nettyThreading.stop() } else { clientChannel?.close()?.sync() } clientChannel = null - workerGroup = null - sslDelegatedTaskExecutor?.shutdown() - sslDelegatedTaskExecutor = null log.info("Stopped connection to $currentTarget") } } @@ -350,4 +343,36 @@ class AMQPClient(private val targets: List, private val _onConnection = PublishSubject.create().toSerialized() val onConnection: Observable get() = _onConnection + + + sealed class NettyThreading { + abstract val eventLoopGroup: EventLoopGroup + abstract val sslDelegatedTaskExecutor: Executor + + class Shared(override val eventLoopGroup: EventLoopGroup, + override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading() + + class NonShared(val threadPoolName: String) : NettyThreading() { + private var _eventLoopGroup: NioEventLoopGroup? = null + override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup) + + private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null + override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor) + + fun start() { + check(_eventLoopGroup == null) + check(_sslDelegatedTaskExecutor == null) + _eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) + _sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) + } + + fun stop() { + eventLoopGroup.shutdownGracefully() + eventLoopGroup.terminationFuture().sync() + sslDelegatedTaskExecutor.shutdown() + _eventLoopGroup = null + _sslDelegatedTaskExecutor = null + } + } + } } diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt index 06e12a966f..ddffb79506 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/CertificateRevocationListNodeTests.kt @@ -300,7 +300,7 @@ abstract class AbstractServerRevocationTest { listOf(NetworkHostAndPort("localhost", targetPort)), setOf(CHARLIE_NAME), amqpConfig, - threadPoolName = legalName.organisation, + nettyThreading = AMQPClient.NettyThreading.NonShared(legalName.organisation), distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout) ) amqpClients += amqpClient diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 89672379d7..f6e4e1d4ed 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -503,7 +503,7 @@ class ProtonWrapperTests { listOf(NetworkHostAndPort("localhost", serverPort)), setOf(ALICE_NAME), amqpConfig, - sharedThreadPool = sharedEventGroup) + nettyThreading = AMQPClient.NettyThreading.Shared(sharedEventGroup)) } private fun createServer(port: Int, diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 7be9b94dc9..ff33390bca 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -5,6 +5,7 @@ import com.codahale.metrics.MetricRegistry import com.google.common.collect.MutableClassToInstanceMap import com.google.common.util.concurrent.MoreExecutors import com.zaxxer.hikari.pool.HikariPool +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.common.logging.errorReporting.NodeDatabaseErrors import net.corda.confidential.SwapIdentitiesFlow import net.corda.core.CordaException @@ -334,7 +335,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, private val schedulerService = makeNodeSchedulerService() private val cordappServices = MutableClassToInstanceMap.create() - private val shutdownExecutor = Executors.newSingleThreadExecutor() + private val shutdownExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("Shutdown")) protected abstract val transactionVerifierWorkerCount: Int /** @@ -770,7 +771,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } else { 1.days } - val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("Network Map Updater")) + val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("NetworkMapPublisher")) executor.submit(object : Runnable { override fun run() { val republishInterval = try { diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 4ee0da30ec..bd3b6cb744 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -415,12 +415,13 @@ open class Node(configuration: NodeConfiguration, } private fun makeBridgeControlListener(serverAddress: NetworkHostAndPort, networkParameters: NetworkParameters): BridgeControlListener { - val artemisMessagingClientFactory = { + val artemisMessagingClientFactory = { threadPoolName: String -> ArtemisMessagingClient( configuration.p2pSslOptions, serverAddress, networkParameters.maxMessageSize, - failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) } + failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) }, + threadPoolName = threadPoolName ) } return BridgeControlListener( @@ -431,7 +432,8 @@ open class Node(configuration: NodeConfiguration, networkParameters.maxMessageSize, configuration.crlCheckSoftFail.toRevocationConfig(), false, - artemisMessagingClientFactory) + artemisMessagingClientFactory + ) } private fun startLocalRpcBroker(securityManager: RPCSecurityManager): BrokerAddresses? { diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index f13d1d73bf..b1e1ceb1f0 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -2,6 +2,7 @@ package net.corda.node.services.events import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationOrigin @@ -148,7 +149,7 @@ class NodeSchedulerService(private val clock: CordaClock, // from the database private val startingStateRefs: MutableSet = ConcurrentHashMap.newKeySet() private val mutex = ThreadBox(InnerState()) - private val schedulerTimerExecutor = Executors.newSingleThreadExecutor() + private val schedulerTimerExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("SchedulerService")) // if there's nothing to do, check every minute if something fell through the cracks. // any new state should trigger a reschedule immediately if nothing is scheduled, so I would not expect diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index aef1c9820c..ed7a5b4cb6 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -65,7 +65,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, private val messagingServerAddress: NetworkHostAndPort, private val maxMessageSize: Int, private val journalBufferTimeout : Int? = null, - private val threadPoolName: String = "ArtemisServer", + private val threadPoolName: String = "P2PServer", private val trace: Boolean = false, private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON, private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt index ffaba4edf4..da3c2a2086 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt @@ -17,6 +17,7 @@ import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor +import net.corda.nodeapi.internal.setThreadPoolName import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor @@ -54,10 +55,23 @@ class NodeNettyAcceptorFactory : AcceptorFactory { handler: BufferHandler?, listener: ServerConnectionLifeCycleListener?, threadPool: Executor, - scheduledThreadPool: ScheduledExecutorService?, + scheduledThreadPool: ScheduledExecutorService, protocolMap: Map>>?): Acceptor { + val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Acceptor", configuration) + threadPool.setThreadPoolName("$threadPoolName-artemis") + scheduledThreadPool.setThreadPoolName("$threadPoolName-artemis-scheduler") val failureExecutor = OrderedExecutor(threadPool) - return NodeNettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) + return NodeNettyAcceptor( + name, + clusterConnection, + configuration, + handler, + listener, + scheduledThreadPool, + failureExecutor, + protocolMap, + "$threadPoolName-netty" + ) } @@ -68,14 +82,14 @@ class NodeNettyAcceptorFactory : AcceptorFactory { listener: ServerConnectionLifeCycleListener?, scheduledThreadPool: ScheduledExecutorService?, failureExecutor: Executor, - protocolMap: Map>>?) : + protocolMap: Map>>?, + private val threadPoolName: String) : NettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) { companion object { private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") } - private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName) private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) diff --git a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt index fac12a9343..584b050425 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapUpdater.kt @@ -74,7 +74,7 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal, } private val parametersUpdatesTrack = PublishSubject.create() - private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("Network Map Updater Thread")).apply { + private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("NetworkMapUpdater")).apply { executeExistingDelayedTasksAfterShutdownPolicy = false } private var newNetworkParameters: Pair? = null @@ -261,9 +261,12 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal, //as HTTP GET is mostly IO bound, use more threads than CPU's //maximum threads to use = 24, as if we did not limit this on large machines it could result in 100's of concurrent requests val threadsToUseForNetworkMapDownload = min(Runtime.getRuntime().availableProcessors() * 4, 24) - val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(threadsToUseForNetworkMapDownload, NamedThreadFactory("NetworkMapUpdaterNodeInfoDownloadThread")) + val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool( + threadsToUseForNetworkMapDownload, + NamedThreadFactory("NetworkMapUpdaterNodeInfoDownload") + ) //DB insert is single threaded - use a single threaded executor for it. - val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsertThread")) + val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsert")) val hashesToFetch = (allHashesFromNetworkMap - allNodeHashes) val networkMapDownloadStartTime = System.currentTimeMillis() if (hashesToFetch.isNotEmpty()) { diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt index 8ed025549c..e48cdf16c0 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/InternalRPCMessagingClient.kt @@ -22,8 +22,7 @@ class InternalRPCMessagingClient(val sslConfig: MutualSslConfiguration, val serv private var rpcServer: RPCServer? = null fun init(rpcOps: List, securityManager: RPCSecurityManager, cacheFactory: NamedCacheFactory) = synchronized(this) { - - val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig) + val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig, threadPoolName = "RPCClient") locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { // Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this // would be the default and the two lines below can be deleted. diff --git a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt index e79f485c01..11ecd7e2c1 100644 --- a/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/rpc/RpcBrokerConfiguration.kt @@ -30,10 +30,10 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int, setDirectories(baseDirectory) val acceptorConfigurationsSet = mutableSetOf( - rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl) + rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl, threadPoolName = "RPCServer") ) adminAddress?.let { - acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration) + acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration, threadPoolName = "RPCServerAdmin") } acceptorConfigurations = acceptorConfigurationsSet diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt index c08515ab0e..734b2b8234 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMonitor.kt @@ -1,5 +1,6 @@ package net.corda.node.services.statemachine +import io.netty.util.concurrent.DefaultThreadFactory import net.corda.core.flows.FlowSession import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine @@ -22,10 +23,6 @@ internal class FlowMonitor( ) : LifecycleSupport { private companion object { - private fun defaultScheduler(): ScheduledExecutorService { - return Executors.newSingleThreadScheduledExecutor() - } - private val logger = loggerFor() } @@ -36,7 +33,7 @@ internal class FlowMonitor( override fun start() { synchronized(this) { if (scheduler == null) { - scheduler = defaultScheduler() + scheduler = Executors.newSingleThreadScheduledExecutor(DefaultThreadFactory("FlowMonitor")) shutdownScheduler = true } scheduler!!.scheduleAtFixedRate({ logFlowsWaitingForParty() }, 0, monitoringPeriod.toMillis(), TimeUnit.MILLISECONDS) From 1e0f7cd69090c566273385ed9a6bf8a5074305db Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Fri, 7 Jul 2023 15:29:44 +0100 Subject: [PATCH 12/12] Merge fix --- .../corda/node/services/messaging/NodeNettyAcceptorFactory.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt index bc37cd747e..9e9c7ca081 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeNettyAcceptorFactory.kt @@ -12,7 +12,6 @@ import net.corda.core.utilities.contextLogger import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor import net.corda.nodeapi.internal.setThreadPoolName -import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor import org.apache.activemq.artemis.core.server.balancing.RedirectHandler