Merge pull request #7381 from corda/shams-4.9-frwd-merge-a817218b

ENT-9806: 4.8 to 4.9 forward merge
This commit is contained in:
Rick Parker 2023-06-05 09:33:24 +01:00 committed by GitHub
commit d0f28a607f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 1246 additions and 951 deletions

View File

@ -0,0 +1,29 @@
package net.corda.coretests.crypto.internal
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.testing.core.createCRL
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.Test
class ProviderMapTest {
// https://github.com/corda/corda/pull/3997
@Test(timeout = 300_000)
fun `verify CRL algorithms`() {
val crl = createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "SHA256withECDSA"
)
// This should pass.
crl.verify(DEV_ROOT_CA.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "EC"
)
}.withMessage("Unknown signature type requested: EC")
}
}

View File

@ -64,7 +64,7 @@ interface ServicesForResolution {
/** /**
* Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState].
* *
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction.
*/ */
// TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load
// as the existing transaction store will become encrypted at some point // as the existing transaction store will become encrypted at some point

View File

@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.hours import net.corda.core.utilities.hours
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.nodeapi.internal.installDevNodeCaCertPath import net.corda.nodeapi.internal.installDevNodeCaCertPath
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.SerializationFactoryImpl import net.corda.serialization.internal.SerializationFactoryImpl
@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.testing.internal.IS_OPENJ9 import net.corda.testing.internal.IS_OPENJ9
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.junit.Assume import org.junit.Assume
@ -74,10 +80,19 @@ import java.security.PrivateKey
import java.security.cert.CertPath import java.security.cert.CertPath
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.* import java.util.*
import javax.net.ssl.* import javax.net.ssl.SSLContext
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.* import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
class X509UtilitiesTest { class X509UtilitiesTest {
private companion object { private companion object {
@ -295,15 +310,10 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom()) context.init(keyManagers, trustManagers, newSecureRandom())
@ -388,15 +398,8 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get() val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustStore = sslConfig.trustStore.get() val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get())
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
val sslServerContext = SslContextBuilder val sslServerContext = SslContextBuilder
.forServer(keyManagerFactory) .forServer(keyManagerFactory)

View File

@ -1,16 +1,18 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path import javax.net.ssl.TrustManagerFactory
@Suppress("LongParameterList") @Suppress("LongParameterList")
class ArtemisTcpTransport { class ArtemisTcpTransport {
@ -23,6 +25,7 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2") val TLS_VERSIONS = listOf("TLSv1.2")
const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout" const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
const val TRACE_NAME = "Corda-Trace" const val TRACE_NAME = "Corda-Trace"
const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName" const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
@ -30,7 +33,6 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats. // It does not use AMQP messages for its own messages e.g. topology and heartbeats.
private const val P2P_PROTOCOLS = "CORE,AMQP" private const val P2P_PROTOCOLS = "CORE,AMQP"
private const val RPC_PROTOCOLS = "CORE" private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf( private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
@ -39,46 +41,35 @@ class ArtemisTcpTransport {
TransportConstants.PORT_PROP_NAME to hostAndPort.port, TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols, TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null), TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to // turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336) //hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false) TransportConstants.DIRECT_DELIVER to false)
private val defaultSSLOptions = mapOf(
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(","))
private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) { private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
if (keyStore != null || trustStore != null) {
options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
}
keyStore?.let { keyStore?.let {
with (it) { with (it) {
path.requireOnDefaultFileSystem() path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path)) options[TransportConstants.KEYSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
} }
} }
trustStore?.let { trustStore?.let {
with (it) { with (it) {
path.requireOnDefaultFileSystem() path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path)) options[TransportConstants.TRUSTSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
} }
} }
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
} }
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf( private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true, TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider, TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider,
@ -94,50 +85,64 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?, config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory),
enableSSL: Boolean = true, enableSSL: Boolean = true,
threadPoolName: String = "P2PServer", threadPoolName: String = "P2PServer",
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (enableSSL) { if (enableSSL) {
config?.addToTransportOptions(options) config?.addToTransportOptions(options)
} }
return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) return createAcceptorTransport(
hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
} }
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?, config: MutualSslConfiguration?,
enableSSL: Boolean = true, enableSSL: Boolean = true,
threadPoolName: String = "P2PClient", threadPoolName: String = "P2PClient",
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (enableSSL) { if (enableSSL) {
config?.addToTransportOptions(options) config?.addToTransportOptions(options)
} }
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
} }
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: BrokerRpcSslOptions?, config: BrokerRpcSslOptions?,
enableSSL: Boolean = true, enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) { if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem() config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions()) options.putAll(config.toTransportOptions())
} }
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace) return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, "RPCServer", trace, remotingThreads)
} }
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: ClientRpcSslOptions?, config: ClientRpcSslOptions?,
enableSSL: Boolean = true, enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) { if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem() config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions()) options.putAll(config.toTransportOptions())
} }
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace) return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
} }
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
@ -145,25 +150,42 @@ class ArtemisTcpTransport {
trace: Boolean = false): TransportConfiguration { trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options) config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace) return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace, null)
} }
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration, config: SslConfiguration,
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options) config.addToTransportOptions(options)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace) return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
"Internal-RPCServer",
trace,
remotingThreads
)
} }
private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort, private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
protocols: String, protocols: String,
options: MutableMap<String, Any>, options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean, enableSSL: Boolean,
threadPoolName: String, threadPoolName: String,
trace: Boolean): TransportConfiguration { trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport( return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory", "net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort, hostAndPort,
@ -171,7 +193,8 @@ class ArtemisTcpTransport {
options, options,
enableSSL, enableSSL,
threadPoolName, threadPoolName,
trace trace,
remotingThreads
) )
} }
@ -180,7 +203,12 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>, options: MutableMap<String, Any>,
enableSSL: Boolean, enableSSL: Boolean,
threadPoolName: String, threadPoolName: String,
trace: Boolean): TransportConfiguration { trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
if (enableSSL) {
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
}
return createTransport( return createTransport(
NodeNettyConnectorFactory::class.java.name, NodeNettyConnectorFactory::class.java.name,
hostAndPort, hostAndPort,
@ -188,7 +216,8 @@ class ArtemisTcpTransport {
options, options,
enableSSL, enableSSL,
threadPoolName, threadPoolName,
trace trace,
remotingThreads
) )
} }
@ -198,13 +227,15 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>, options: MutableMap<String, Any>,
enableSSL: Boolean, enableSSL: Boolean,
threadPoolName: String, threadPoolName: String,
trace: Boolean): TransportConfiguration { trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols) options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) { if (enableSSL) {
options += defaultSSLOptions options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
// This is required to stop Client checking URL address vs. Server provided certificate options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
} }
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace options[TRACE_NAME] = trace
return TransportConfiguration(className, options) return TransportConfiguration(className, options)

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -100,7 +100,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private class AMQPBridge(val sourceX500Name: String, private class AMQPBridge(val sourceX500Name: String,
val queueName: String, val queueName: String,
val targets: List<NetworkHostAndPort>, val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>, val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration, private val amqpConfig: AMQPConfiguration,
sharedEventGroup: EventLoopGroup, sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider, private val artemis: ArtemisSessionProvider,
@ -116,7 +116,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName) MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name) MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() }) MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() }) MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString()) MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block() block()
} finally { } finally {
@ -134,7 +134,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup) val amqpClient = AMQPClient(targets, allowedRemoteLegalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
private var session: ClientSession? = null private var session: ClientSession? = null
private var consumer: ClientConsumer? = null private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null private var connectedSubscription: Subscription? = null
@ -231,7 +231,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
ArtemisState.STOPPING ArtemisState.STOPPING
} }
bridgeMetricsService?.bridgeDisconnected(targets, legalNames) bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe() connectedSubscription?.unsubscribe()
connectedSubscription = null connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first. // Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -243,7 +243,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) { if (connected) {
logInfoWithMDC("Bridge Connected") logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames) bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) { if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval // AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS, amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -286,7 +286,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected") logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false) amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) { if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames) bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
} }
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState -> artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected") logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -418,10 +418,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value properties[key] = value
} }
} }
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName) val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox, val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(), allowedRemoteLegalNames.first().toString(),
properties) properties)
sendableMessage.onComplete.then { sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" } logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -486,7 +486,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName) queueNamesToBridgesMap.remove(queueName)
} }
bridge.stop() bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames) bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
} }
} }
} }
@ -498,7 +498,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
val bridges = queueNamesToBridgesMap[queueName]?.toList() val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map { bridges?.map {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false) it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap() }?.toMap() ?: emptyMap()
} }
} }

View File

@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.* import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.millis import net.corda.core.utilities.millis
import net.corda.core.utilities.toHex import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.* import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -32,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.security.SignatureException import java.security.SignatureException
import java.security.cert.* import java.security.cert.CertPath
import java.security.cert.Certificate import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
@ -359,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) { private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) { if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer)) GeneralNames(GeneralName(crlIssuer))
} }
@ -379,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40) bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes) return BigInteger(bytes)
} }
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
} }
// Assuming cert type to role is 1:1 // Assuming cert type to role is 1:1

View File

@ -27,15 +27,14 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.lang.Long.min import java.lang.Long.min
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
enum class ProxyVersion { enum class ProxyVersion {
@ -63,7 +62,8 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>, val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration, private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null, private val sharedThreadPool: EventLoopGroup? = null,
private val threadPoolName: String = "AMQPClient") : AutoCloseable { private val threadPoolName: String = "AMQPClient",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object { companion object {
init { init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -89,12 +89,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private var targetIndex = 0 private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first() private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL private var retryInterval = MIN_RETRY_INTERVAL
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private val badCertTargets = mutableSetOf<NetworkHostAndPort>() private val badCertTargets = mutableSetOf<NetworkHostAndPort>()
@Volatile @Volatile
private var amqpActive = false private var amqpActive = false
@Volatile @Volatile
private var amqpChannelHandler: ChannelHandler? = null private var amqpChannelHandler: ChannelHandler? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
val localAddressString: String val localAddressString: String
get() = clientChannel?.localAddress()?.toString() ?: "<unknownLocalAddress>" get() = clientChannel?.localAddress()?.toString() ?: "<unknownLocalAddress>"
@ -150,17 +150,16 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
} }
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() { private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration private val conf = parent.configuration
@Volatile @Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
@Suppress("ComplexMethod") @Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) { override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
@ -199,10 +198,24 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget val target = parent.currentTarget
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (parent.configuration.useOpenSsl) { val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
delegatedTaskExecutor
)
} else { } else {
createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
delegatedTaskExecutor
)
} }
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis() handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler) pipeline.addLast("sslHandler", handler)
@ -260,6 +273,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
return return
} }
log.info("Connect to: $currentTarget") log.info("Connect to: $currentTarget")
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
started = true started = true
restart() restart()
@ -294,6 +308,8 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
} }
clientChannel = null clientChannel = null
workerGroup = null workerGroup = null
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
log.info("Stopped connection to $currentTarget") log.info("Stopped connection to $currentTarget")
} }
} }
@ -334,6 +350,4 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized() private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange> val onConnection: Observable<ConnectionChange>
get() = _onConnection get() = _onConnection
}
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}

View File

@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery import org.apache.qpid.proton.engine.Delivery
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.net.BindException import java.net.BindException
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
/** /**
@ -39,37 +38,34 @@ import kotlin.concurrent.withLock
class AMQPServer(val hostName: String, class AMQPServer(val hostName: String,
val port: Int, val port: Int,
private val configuration: AMQPConfiguration, private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer") : AutoCloseable { private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object { companion object {
init { init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
} }
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger() private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4) private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
} }
private val lock = ReentrantLock() private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null private var serverChannel: Channel? = null
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker() private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>() private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() { private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
override fun initChannel(ch: SocketChannel) { override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
@ -116,11 +112,12 @@ class AMQPServer(val hostName: String,
Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else { } else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig) val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) { val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc()) createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else { } else {
// For javaSSL, SNI matching is handled at key manager level. // For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory) createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
} }
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis() handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory)) Pair(handler, mapOf(DEFAULT to keyManagerFactory))
@ -132,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock { lock.withLock {
stop() stop()
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY)) bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)) workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap() val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -154,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() { fun stop() {
lock.withLock { lock.withLock {
try { serverChannel?.close()
stopping = true serverChannel = null
serverChannel?.apply { close() }
serverChannel = null
workerGroup?.shutdownGracefully() workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync() workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully() bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync() bossGroup?.terminationFuture()?.sync()
bossGroup = null
workerGroup = null sslDelegatedTaskExecutor?.shutdown()
bossGroup = null sslDelegatedTaskExecutor = null
} finally {
stopping = false
}
} }
} }
@ -226,6 +225,4 @@ class AMQPServer(val hostName: String,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized() private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange> val onConnection: Observable<ConnectionChange>
get() = _onConnection get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
} }

View File

@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
override fun getSoftFailExceptions(): List<CertPathValidatorException> { override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return Collections.emptyList() return Collections.emptyList()
} }
override fun clone(): AllowAllRevocationChecker = this
} }

View File

@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import com.typesafe.config.Config import com.typesafe.config.Config
import net.corda.nodeapi.internal.config.ConfigParser import net.corda.nodeapi.internal.config.ConfigParser
import net.corda.nodeapi.internal.config.CustomConfigParser import net.corda.nodeapi.internal.config.CustomConfigParser
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import java.security.cert.PKIXRevocationChecker
/** /**
* Data structure for controlling the way how Certificate Revocation Lists are handled. * Data structure for controlling the way how Certificate Revocation Lists are handled.
@ -45,18 +42,6 @@ interface RevocationConfig {
* Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE` * Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/ */
val externalCrlSource: CrlSource? val externalCrlSource: CrlSource?
fun createPKIXRevocationChecker(): PKIXRevocationChecker {
return when (mode) {
Mode.OFF -> AllowAllRevocationChecker
Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" }
CordaRevocationChecker(externalCrlSource, softFail = true)
}
Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true)
Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false)
}
}
} }
/** /**

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator import io.netty.buffer.ByteBufAllocator
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509 import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.DERIA5String import org.bouncycastle.asn1.DERIA5String
@ -34,10 +38,10 @@ import java.net.URI
import java.security.KeyStore import java.security.KeyStore
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import javax.net.ssl.CertPathTrustManagerParameters import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName import javax.net.ssl.SNIHostName
@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.system.measureTimeMillis
private const val HOSTNAME_FORMAT = "%s.corda.net" private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default" internal const val DEFAULT = "default"
@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton
/** /**
* Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any. * Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/ */
@Suppress("ComplexMethod")
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> { fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" } logger.debug { "Checking CRLDPs for $subjectX500Principal" }
@ -117,6 +119,14 @@ fun certPathToString(certPath: Array<out X509Certificate>?): String {
return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" } return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
} }
/**
* Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
* which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
*/
fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
}
@VisibleForTesting @VisibleForTesting
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() { class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object { companion object {
@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
} }
private object LoggingImmediateExecutor : Executor {
override fun execute(command: Runnable) {
val log = LoggerFactory.getLogger(javaClass)
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHandler(target: NetworkHostAndPort, internal fun createClientSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>, expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port) val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true sslEngine.useClientMode = true
@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
} }
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
internal fun createClientOpenSslHandler(target: NetworkHostAndPort, internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>, expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory, trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler { alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port) val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
} }
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
internal fun createServerSslHandler(keyStore: CertificateStore, internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine() val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false sslEngine.useClientMode = false
@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore)) sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory, trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler { alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build() val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc) val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false sslEngine.useClientMode = false
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext { fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS") val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers val trustManagers = trustManagerFactory
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java) ?.trustManagers
.map { LoggingTrustManagerWrapper(it) }.toTypedArray() ?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
sslContext.init(keyManagers, trustManagers, newSecureRandom()) ?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext return sslContext
} }
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationConfig: RevocationConfig): CertPathTrustManagerParameters {
return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker())
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
}
/** /**
* Creates a special SNI handler used only when openSSL is used for AMQPServer * Creates a special SNI handler used only when openSSL is used for AMQPServer
*/ */
@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build()) return SniHandler(mapping.build())
} }
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder { private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory) return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL) .sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)) .trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE) .clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES) .ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()) .protocols(ArtemisTcpTransport.TLS_VERSIONS)
} }
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> { internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
@ -327,7 +307,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
// 2nd parameter `password` - the password for recovering keys in the KeyStore // 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray()) fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal) fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
return keyManagerFactory
}
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/** /**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target * Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache
import net.corda.core.internal.readFully import net.corda.core.internal.readFully
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints
import java.net.URI import java.net.URI
import java.security.cert.X509CRL import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit import java.time.Duration
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
/** /**
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate. * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them.
*/ */
@Suppress("TooGenericExceptionCaught") @Suppress("TooGenericExceptionCaught")
class CertDistPointCrlSource : CrlSource { class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE,
cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY,
private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT,
private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource {
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
// The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a // The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a
// node handshake, we want to keep the total timeout within that. // node handshake, we want to keep the total timeout within that.
private const val DEFAULT_CONNECT_TIMEOUT = 9_000 private val DEFAULT_CONNECT_TIMEOUT = 9.seconds
private const val DEFAULT_READ_TIMEOUT = 9_000 private val DEFAULT_READ_TIMEOUT = 9.seconds
private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore) private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore)
private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L private val DEFAULT_CACHE_EXPIRY = 5.minutes
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder() val SINGLETON = CertDistPointCrlSource(
.maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE)) cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE),
.expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS) cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY,
.build(::retrieveCRL) connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT,
readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT
)
}
private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT) private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT) .maximumSize(cacheSize)
.expireAfterWrite(cacheExpiry)
.build(::retrieveCRL)
private fun retrieveCRL(uri: URI): X509CRL { private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
val bytes = try { val bytes = try {
val conn = uri.toURL().openConnection() val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout conn.connectTimeout = connectTimeout.toMillis().toInt()
conn.readTimeout = readTimeout conn.readTimeout = readTimeout.toMillis().toInt()
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain. // into CRLException and drops the cause chain.
conn.getInputStream().readFully() conn.getInputStream().readFully()
} catch (e: Exception) { } catch (e: Exception) {
if (logger.isDebugEnabled) { if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
throw e
} }
val duration = System.currentTimeMillis() - start throw e
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
} }
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
}
fun clearCache() {
cache.invalidateAll()
} }
override fun fetch(certificate: X509Certificate): Set<X509CRL> { override fun fetch(certificate: X509Certificate): Set<X509CRL> {

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply { return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(keyStore)
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(trustStore)
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom()) init(keyManagers, trustManagers, newSecureRandom())
} }

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Ignore import org.junit.Ignore
import org.junit.Rule import org.junit.Rule
@ -18,7 +19,6 @@ import java.io.IOException
import java.net.InetAddress import java.net.InetAddress
import java.net.InetSocketAddress import java.net.InetSocketAddress
import javax.net.ssl.* import javax.net.ssl.*
import javax.net.ssl.SNIHostName
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply { return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(keyStore)
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(trustStore)
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom()) init(keyManagers, trustManagers, newSecureRandom())
} }

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.util.concurrent.ImmediateExecutor
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.testing.internal.fixedCrlSource
import org.junit.Test import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName import javax.net.ssl.SNIHostName
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals import kotlin.test.assertEquals
class SSLHelperTest { class SSLHelperTest {
@ -20,15 +20,21 @@ class SSLHelperTest {
val legalName = CordaX500Name("Test", "London", "GB") val legalName = CordaX500Name("Test", "London", "GB")
val sslConfig = configureTestSSL(legalName) val sslConfig = configureTestSSL(legalName)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val keyStore = sslConfig.keyStore val trustManagerFactory = trustManagerFactoryWithRevocation(
keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false)) sslConfig.trustStore.get(),
val trustStore = sslConfig.trustStore RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL),
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL))) fixedCrlSource(emptySet())
)
val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory) val sslHandler = createClientSslHandler(
NetworkHostAndPort("localhost", 1234),
setOf(legalName),
keyManagerFactory,
trustManagerFactory,
ImmediateExecutor.INSTANCE
)
val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase() val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
// These hardcoded values must not be changed, something is broken if you have to change these hardcoded values. // These hardcoded values must not be changed, something is broken if you have to change these hardcoded values.

View File

@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.node.internal.network.CrlServer import net.corda.testing.node.internal.network.CrlServer
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.math.BigInteger
class CertDistPointCrlSourceTest { class CertDistPointCrlSourceTest {
private lateinit var crlServer: CrlServer private lateinit var crlServer: CrlServer
@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest {
assertThat(single().revokedCertificates).isNull() assertThat(single().revokedCertificates).isNull()
} }
val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate) crlSource.clearCache()
crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN) crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate
with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache with(crlSource.fetch(crlServer.intermediateCa.certificate)) {
assertThat(size).isEqualTo(1) assertThat(size).isEqualTo(1)
val revokedCertificates = single().revokedCertificates val revokedCertificates = single().revokedCertificates
assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN) // This also tests clearCache() works.
assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber)
} }
} }
} }

View File

@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.testing.internal.fixedCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test import org.junit.Test
import java.math.BigInteger import java.math.BigInteger
@ -41,10 +41,8 @@ class CordaRevocationCheckerTest {
val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl") val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL
val crlSource = object : CrlSource { val checker = CordaRevocationChecker(
override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl) crlSource = fixedCrlSource(setOf(crl)),
}
val checker = CordaRevocationChecker(crlSource,
softFail = true, softFail = true,
dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) } dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) }
) )

View File

@ -1,20 +1,16 @@
package net.corda.node.internal.artemis package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.utilities.days import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509KeyStore
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.Parameterized import org.junit.runners.Parameterized
import java.io.File import java.io.File
import java.security.KeyPair
import java.security.KeyStore import java.security.KeyStore
import java.security.PrivateKey import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.* import java.util.*
import javax.net.ssl.X509TrustManager
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.test.assertFails import kotlin.test.assertFailsWith
@RunWith(Parameterized::class) @RunWith(Parameterized::class)
class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { class RevocationTest(private val revocationMode: RevocationConfig.Mode) {
companion object { companion object {
@JvmStatic @JvmStatic
@Parameterized.Parameters(name = "revocationMode = {0}") @Parameterized.Parameters(name = "revocationMode = {0}")
@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var doormanCRL: File private lateinit var doormanCRL: File
private lateinit var tlsCRL: File private lateinit var tlsCRL: File
private val keyStore = KeyStore.getInstance("JKS") private lateinit var trustManager: X509TrustManager
private val trustStore = KeyStore.getInstance("JKS")
private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
@ -61,7 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var tlsCert: X509Certificate private lateinit var tlsCert: X509Certificate
private val chain private val chain
get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).toTypedArray() get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert)
@Before @Before
fun before() { fun before() {
@ -72,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair) rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair)
tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair) tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair)
val trustStore = KeyStore.getInstance("JKS")
trustStore.load(null, null) trustStore.load(null, null)
trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert) trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert)
trustStore.setCertificateEntry("cordarootca", rootCert) trustStore.setCertificateEntry("cordarootca", rootCert)
val trustManagerFactory = trustManagerFactoryWithRevocation(
CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"),
RevocationConfigImpl(revocationMode),
CertDistPointCrlSource()
)
trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager
doormanCert = X509Utilities.createCertificate( doormanCert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public, CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public,
crlDistPoint = rootCRL.toURI().toString() crlDistPoint = rootCRL.toURI().toString()
@ -89,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
) )
rootCRL.createCRL(rootCert, rootKeyPair.private, false) rootCRL.writeCRL(rootCert, rootKeyPair.private, false)
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false) doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
} }
private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date()) val crl = createCRL(
builder.setNextUpdate(Date.from(Date().toInstant() + 7.days)) CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)),
builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false)) revoked.asList(),
revoked.forEach { indirect = indirect
val extensionsGenerator = ExtensionsGenerator() )
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise)) writeBytes(crl.encoded)
// Certificate issuer is required for indirect CRL
val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded)
extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName)))
builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate())
}
val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey))
outputStream().use { it.write(holder.encoded) }
} }
private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) { private fun assertFailsFor(vararg modes: RevocationConfig.Mode) {
if (revocationMode in modes) assertFails(block) else block() if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck()
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `ok with empty CRLs`() { fun `ok with empty CRLs`() {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) doRevocationCheck()
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `soft fail with revoked TLS certificate`() { fun `soft fail with revoked TLS certificate`() {
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -136,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -148,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown") crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown")
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -160,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString() crlDistPoint = tlsCRL.toURI().toString()
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -172,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public, CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
) )
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) doRevocationCheck()
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `soft fail with revoked node CA certificate`() { fun `soft fail with revoked node CA certificate`() {
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -193,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman" crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman"
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -205,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public, CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString() crlDistPoint = doormanCRL.toURI().toString()
) )
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert) doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) doRevocationCheck()
}
private fun doRevocationCheck() {
trustManager.checkClientTrusted(chain, "ECDHE_ECDSA")
} }
} }

View File

@ -269,8 +269,6 @@ tasks.register('integrationTest', Test) {
testClassesDirs = sourceSets.integrationTest.output.classesDirs testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath classpath = sourceSets.integrationTest.runtimeClasspath
maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger() maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger()
// CertificateRevocationListNodeTests
systemProperty 'net.corda.dpcrl.connect.timeout', '4000'
} }
tasks.register('slowIntegrationTest', Test) { tasks.register('slowIntegrationTest', Test) {

View File

@ -14,12 +14,15 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.fixedCrlSource
import org.junit.Assume.assumeFalse import org.junit.Assume.assumeFalse
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
@ -96,11 +99,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val maxMessageSize: Int = MAX_MESSAGE_SIZE override val maxMessageSize: Int = MAX_MESSAGE_SIZE
} }
serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) serverKeyManagerFactory = keyManagerFactory(keyStore)
serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
serverKeyManagerFactory.init(keyStore) serverTrustManagerFactory = trustManagerFactoryWithRevocation(
serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig)) serverAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
} }
private fun setupClientCertificates() { private fun setupClientCertificates() {
@ -127,11 +132,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val sslHandshakeTimeout: Duration = 3.seconds override val sslHandshakeTimeout: Duration = 3.seconds
} }
clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) clientKeyManagerFactory = keyManagerFactory(keyStore)
clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
clientKeyManagerFactory.init(keyStore) clientTrustManagerFactory = trustManagerFactoryWithRevocation(
clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig)) clientAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.node.amqp package net.corda.node.amqp
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
@ -5,10 +7,10 @@ import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.internal.rootCause
import net.corda.core.internal.times import net.corda.core.internal.times
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
@ -18,64 +20,68 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.config.CertificateStoreSupplier import net.corda.nodeapi.internal.config.CertificateStoreSupplier
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.node.internal.network.CrlServer import net.corda.testing.node.internal.network.CrlServer
import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.FORBIDDEN_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint
import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import java.net.SocketTimeoutException import java.io.Closeable
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals import java.util.stream.IntStream
@Suppress("LongParameterList") abstract class AbstractServerRevocationTest {
class CertificateRevocationListNodeTests {
@Rule @Rule
@JvmField @JvmField
val temporaryFolder = TemporaryFolder() val temporaryFolder = TemporaryFolder()
private val portAllocation = incrementalPortAllocation() private val portAllocation = incrementalPortAllocation()
private val serverPort = portAllocation.nextPort() protected val serverPort = portAllocation.nextPort()
private lateinit var crlServer: CrlServer protected lateinit var crlServer: CrlServer
private lateinit var amqpServer: AMQPServer private val amqpClients = ArrayList<AMQPClient>()
private lateinit var amqpClient: AMQPClient
private abstract class AbstractNodeConfiguration : NodeConfiguration protected lateinit var defaultCrlDistPoints: CrlDistPoints
protected abstract class AbstractNodeConfiguration : NodeConfiguration
companion object { companion object {
private val unreachableIpCounter = AtomicInteger(1) private val unreachableIpCounter = AtomicInteger(1)
private val crlConnectTimeout = Duration.ofMillis(System.getProperty("net.corda.dpcrl.connect.timeout").toLong()) val crlConnectTimeout = 2.seconds
/** /**
* Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes * Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes
* may not work as the OS process may cache the timeout result. * may not work as the OS process may cache the timeout result.
*/ */
private fun newUnreachableIpAddress(): String { private fun newUnreachableIpAddress(): NetworkHostAndPort {
check(unreachableIpCounter.get() != 255) check(unreachableIpCounter.get() != 255)
return "10.255.255.${unreachableIpCounter.getAndIncrement()}" return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement())
} }
} }
@ -85,252 +91,190 @@ class CertificateRevocationListNodeTests {
Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME) Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME)
crlServer = CrlServer(NetworkHostAndPort("localhost", 0)) crlServer = CrlServer(NetworkHostAndPort("localhost", 0))
crlServer.start() crlServer.start()
defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort)
} }
@After @After
fun tearDown() { fun tearDown() {
if (::amqpClient.isInitialized) { amqpClients.parallelStream().forEach(AMQPClient::close)
amqpClient.close()
}
if (::amqpServer.isInitialized) {
amqpServer.close()
}
if (::crlServer.isInitialized) { if (::crlServer.isInitialized) {
crlServer.close() crlServer.close()
} }
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection works and soft fail is enabled`() { fun `connection succeeds when soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection works and soft fail is disabled`() { fun `connection succeeds when soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is enabled`() { fun `connection fails when client's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
revokeClientCert = true, revokeClientCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is disabled`() { fun `connection fails when client's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
revokeClientCert = true, revokeClientCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is enabled`() { fun `connection fails when server's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
revokeServerCert = true, revokeServerCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is disabled`() { fun `connection fails when server's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
revokeServerCert = true, revokeServerCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL cannot be obtained and soft fail is enabled`() { fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when CRL cannot be obtained and soft fail is disabled`() { fun `connection fails when CRL cannot be obtained and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL is not defined and soft fail is enabled`() { fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = null, clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when CRL is not defined and soft fail is disabled`() { fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
nodeCrlDistPoint = null, clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() { fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL", clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = false,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
sslHandshakeTimeout = crlConnectTimeout * 3, expectedConnectedStatus = false
expectedConnectStatus = true
) )
val timeoutExceptions = (amqpServer.softFailExceptions + amqpClient.softFailExceptions)
.map { it.rootCause }
.filterIsInstance<SocketTimeoutException>()
assertThat(timeoutExceptions).isNotEmpty
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true,
sslHandshakeTimeout = crlConnectTimeout * 4,
clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() {
verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2, sslHandshakeTimeout = crlConnectTimeout / 2,
expectedConnectStatus = false clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout = 300_000)
fun `verify CRL algorithms`() { fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() {
val crl = crlServer.createRevocationList( val serverCrlSource = CertDistPointCrlSource()
"SHA256withECDSA", // Start the server and verify the first client has connected
crlServer.rootCa, val firstClientConnectionChangeStatus = verifyConnection(
EMPTY_CRL, crlCheckSoftFail = true,
true, crlSource = serverCrlSource,
emptyList() // In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with
// existing clients. The trick is to make sure at least N new clients are also supported.
remotingThreads = 2,
expectedConnectedStatus = true
) )
// This should pass.
crl.verify(crlServer.rootCa.keyPair.public)
// Try changing the algorithm to EC will fail. // Now simulate the CRL endpoint becoming very slow/unreachable
assertThatIllegalArgumentException().isThrownBy { crlServer.delay = 10.minutes
crlServer.createRevocationList( // And pretend enough time has elapsed that the cached CRLs have expired and need downloading again
"EC", serverCrlSource.clearCache()
crlServer.rootCa,
EMPTY_CRL, // Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty
true, // threads to be tied up in trying to download the CRLs.
emptyList() IntStream.range(0, 2).parallel().forEach { clientIndex ->
val (newClient, _) = createAMQPClient(
serverPort,
crlCheckSoftFail = true,
legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"),
crlDistPoints = defaultCrlDistPoints
) )
}.withMessage("Unknown signature type requested: EC") newClient.start()
}
// Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga
assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull()
} }
@Test(timeout = 300_000) protected abstract fun verifyConnection(crlCheckSoftFail: Boolean,
fun `Artemis server connection succeeds with soft fail CRL check`() { crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout),
verifyArtemisConnection( sslHandshakeTimeout: Duration? = null,
crlCheckSoftFail = true, remotingThreads: Int? = null,
crlCheckArtemisServer = true, clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints,
expectedStatus = MessageStatus.Acknowledged revokeClientCert: Boolean = false,
) revokeServerCert: Boolean = false,
} expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange>
@Test(timeout = 300_000) protected fun createAMQPClient(targetPort: Int,
fun `Artemis server connection succeeds with hard fail CRL check`() { crlCheckSoftFail: Boolean,
verifyArtemisConnection( legalName: CordaX500Name,
crlCheckSoftFail = false, crlDistPoints: CrlDistPoints): Pair<AMQPClient, X509Certificate> {
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unreachable URL if CRL timeout is within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout * 3
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on unreachable URL if CRL timeout is not within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedConnected = false,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
revokedNodeCert = true
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = false,
expectedStatus = MessageStatus.Acknowledged,
revokedNodeCert = true
)
}
private fun createAMQPClient(targetPort: Int,
crlCheckSoftFail: Boolean,
legalName: CordaX500Name,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL",
maxMessageSize: Int = MAX_MESSAGE_SIZE): X509Certificate {
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates" val certificatesDirectory = baseDirectory / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
@ -344,31 +288,128 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
} }
clientConfig.configureWithDevSSLCertificate() clientConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = clientConfig.p2pSslOptions.keyStore.get() val keyStore = clientConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration { val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore override val keyStore = keyStore
override val trustStore = clientConfig.p2pSslOptions.trustStore.get() override val trustStore = clientConfig.p2pSslOptions.trustStore.get()
override val maxMessageSize: Int = maxMessageSize override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val trace: Boolean = true
} }
amqpClient = AMQPClient( val amqpClient = AMQPClient(
listOf(NetworkHostAndPort("localhost", targetPort)), listOf(NetworkHostAndPort("localhost", targetPort)),
setOf(CHARLIE_NAME), setOf(CHARLIE_NAME),
amqpConfig, amqpConfig,
threadPoolName = legalName.organisation threadPoolName = legalName.organisation,
distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout)
) )
amqpClients += amqpClient
return Pair(amqpClient, nodeCert)
}
return nodeCert protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val connectionChangeStatus = LinkedBlockingQueue<ConnectionChange>()
onConnection.subscribe { connectionChangeStatus.add(it) }
start()
assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus)
return connectionChangeStatus
}
protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort,
val nodeCa: String? = NODE_CRL,
val tls: String? = EMPTY_CRL) {
private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" }
private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" }
fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier,
p2pSslConfiguration: MutualSslConfiguration,
crlServer: CrlServer): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
}
}
class AMQPServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var amqpServer: AMQPServer
@After
fun shutDown() {
if (::amqpServer.isInitialized) {
amqpServer.close()
}
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(serverCert)
}
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (client, clientCert) = createAMQPClient(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
legalName = ALICE_NAME,
crlDistPoints = clientCrlDistPoints
)
if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert)
}
return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
} }
private fun createAMQPServer(port: Int, private fun createAMQPServer(port: Int,
legalName: CordaX500Name, legalName: CordaX500Name,
crlCheckSoftFail: Boolean, crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", crlDistPoints: CrlDistPoints,
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL", distPointCrlSource: CertDistPointCrlSource,
maxMessageSize: Int = MAX_MESSAGE_SIZE, sslHandshakeTimeout: Duration?,
sslHandshakeTimeout: Duration? = null): X509Certificate { remotingThreads: Int?): X509Certificate {
check(!::amqpServer.isInitialized) check(!::amqpServer.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates" val certificatesDirectory = baseDirectory / "certificates"
@ -382,92 +423,103 @@ class CertificateRevocationListNodeTests {
doReturn(signingCertificateStore).whenever(it).signingCertificateStore doReturn(signingCertificateStore).whenever(it).signingCertificateStore
} }
serverConfig.configureWithDevSSLCertificate() serverConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = serverConfig.p2pSslOptions.keyStore.get() val keyStore = serverConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration { val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore override val keyStore = keyStore
override val trustStore = serverConfig.p2pSslOptions.trustStore.get() override val trustStore = serverConfig.p2pSslOptions.trustStore.get()
override val revocationConfig = crlCheckSoftFail.toRevocationConfig() override val revocationConfig = crlCheckSoftFail.toRevocationConfig()
override val maxMessageSize: Int = maxMessageSize override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout
} }
amqpServer = AMQPServer("0.0.0.0", port, amqpConfig, threadPoolName = legalName.organisation) amqpServer = AMQPServer(
return nodeCert "0.0.0.0",
} port,
amqpConfig,
private fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, threadPoolName = legalName.organisation,
p2pSslConfiguration: MutualSslConfiguration, distPointCrlSource = distPointCrlSource,
nodeCaCrlDistPoint: String?, remotingThreads = remotingThreads
tlsCrlDistPoint: String?): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
private fun verifyAMQPConnection(crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
revokeServerCert: Boolean = false,
revokeClientCert: Boolean = false,
sslHandshakeTimeout: Duration? = null,
expectedConnectStatus: Boolean) {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail = crlCheckSoftFail,
nodeCrlDistPoint = nodeCrlDistPoint,
sslHandshakeTimeout = sslHandshakeTimeout
) )
if (revokeServerCert) { return serverCert
crlServer.revokedNodeCerts.add(serverCert.serialNumber) }
}
class ArtemisServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var artemisNode: ArtemisNode
private var crlCheckArtemisServer = true
@After
fun shutDown() {
if (::artemisNode.isInitialized) {
artemisNode.close()
} }
amqpServer.start() }
amqpServer.onReceive.subscribe {
it.complete(true) @Test(timeout = 300_000)
} fun `connection succeeds with disabled CRL check on revoked node certificate`() {
val clientCert = createAMQPClient( crlCheckArtemisServer = false
verifyConnection(
crlCheckSoftFail = false,
revokeClientCert = true,
expectedConnectedStatus = true
)
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val (client, clientCert) = createAMQPClient(
serverPort, serverPort,
crlCheckSoftFail = crlCheckSoftFail, crlCheckSoftFail = true,
legalName = ALICE_NAME, legalName = ALICE_NAME,
nodeCrlDistPoint = nodeCrlDistPoint crlDistPoints = clientCrlDistPoints
) )
if (revokeClientCert) { if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert.serialNumber) crlServer.revokedNodeCerts.add(clientCert)
} }
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.start() val nodeCert = startArtemisNode(
val serverConnect = serverConnected.get() CHARLIE_NAME,
assertThat(serverConnect.connected).isEqualTo(expectedConnectStatus) crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(nodeCert)
}
val queueName = "${P2P_PREFIX}Test"
artemisNode.client.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
if (expectedConnectedStatus) {
val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
client.write(msg)
assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged)
}
return clientConnectionChangeStatus
} }
private fun createArtemisServerAndClient(legalName: CordaX500Name, private fun startArtemisNode(legalName: CordaX500Name,
crlCheckSoftFail: Boolean, crlCheckSoftFail: Boolean,
crlCheckArtemisServer: Boolean, crlDistPoints: CrlDistPoints,
nodeCrlDistPoint: String, distPointCrlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?): Pair<ArtemisMessagingServer, ArtemisMessagingClient> { sslHandshakeTimeout: Duration?,
val baseDirectory = temporaryFolder.root.toPath() / "artemis" remotingThreads: Int?): X509Certificate {
check(!::artemisNode.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates" val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout) val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout)
@ -483,62 +535,34 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer
} }
artemisConfig.configureWithDevSSLCertificate() artemisConfig.configureWithDevSSLCertificate()
recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, null) val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val server = ArtemisMessagingServer( val server = ArtemisMessagingServer(
artemisConfig, artemisConfig,
artemisConfig.p2pAddress, artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-server", threadPoolName = "${legalName.organisation}-server",
trace = true trace = true,
distPointCrlSource = distPointCrlSource,
remotingThreads = remotingThreads
) )
val client = ArtemisMessagingClient( val client = ArtemisMessagingClient(
artemisConfig.p2pSslOptions, artemisConfig.p2pSslOptions,
artemisConfig.p2pAddress, artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-client", threadPoolName = "${legalName.organisation}-client"
trace = true
) )
server.start() server.start()
client.start() client.start()
return server to client val artemisNode = ArtemisNode(server, client)
this.artemisNode = artemisNode
return nodeCert
} }
private fun verifyArtemisConnection(crlCheckSoftFail: Boolean, private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable {
crlCheckArtemisServer: Boolean, override fun close() {
expectedConnected: Boolean = true, client.stop()
expectedStatus: MessageStatus? = null, server.close()
revokedNodeCert: Boolean = false,
nodeCrlDistPoint: String = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
sslHandshakeTimeout: Duration? = null) {
val queueName = P2P_PREFIX + "Test"
val (artemisServer, artemisClient) = createArtemisServerAndClient(
CHARLIE_NAME,
crlCheckSoftFail,
crlCheckArtemisServer,
nodeCrlDistPoint,
sslHandshakeTimeout
)
artemisServer.use {
artemisClient.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val nodeCert = createAMQPClient(serverPort, true, ALICE_NAME, nodeCrlDistPoint)
if (revokedNodeCert) {
crlServer.revokedNodeCerts.add(nodeCert.serialNumber)
}
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val clientConnect = clientConnected.get()
assertThat(clientConnect.connected).isEqualTo(expectedConnected)
if (expectedConnected) {
val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
amqpClient.write(msg)
assertEquals(expectedStatus, msg.onComplete.get())
}
artemisClient.stop()
} }
} }
} }

View File

@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import io.netty.channel.EventLoopGroup import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.nio.NioEventLoopGroup
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.ArtemisMessagingServer
@ -23,7 +26,9 @@ import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
@ -31,9 +36,6 @@ import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.createDevIntermediateCaCertPath import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
@ -43,7 +45,11 @@ import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import javax.net.ssl.* import javax.net.ssl.SSLContext
import javax.net.ssl.SSLHandshakeException
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -145,15 +151,10 @@ class ProtonWrapperTests {
sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) } sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) }
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom()) context.init(keyManagers, trustManagers, newSecureRandom())
@ -344,7 +345,7 @@ class ProtonWrapperTests {
amqpServer.use { amqpServer.use {
val connectionEvents = amqpServer.onConnection.toBlocking().iterator val connectionEvents = amqpServer.onConnection.toBlocking().iterator
amqpServer.start() amqpServer.start()
val sharedThreads = NioEventLoopGroup() val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads"))
val amqpClient1 = createSharedThreadsClient(sharedThreads, 0) val amqpClient1 = createSharedThreadsClient(sharedThreads, 0)
val amqpClient2 = createSharedThreadsClient(sharedThreads, 1) val amqpClient2 = createSharedThreadsClient(sharedThreads, 1)
amqpClient1.start() amqpClient1.start()

View File

@ -3,6 +3,8 @@ package net.corda.services.messaging
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.qpid.jms.JmsConnectionFactory import org.apache.qpid.jms.JmsConnectionFactory
import org.apache.qpid.jms.meta.JmsConnectionInfo import org.apache.qpid.jms.meta.JmsConnectionInfo
import org.apache.qpid.jms.provider.Provider import org.apache.qpid.jms.provider.Provider
@ -24,9 +26,7 @@ import javax.jms.Connection
import javax.jms.Message import javax.jms.Message
import javax.jms.MessageProducer import javax.jms.MessageProducer
import javax.jms.Session import javax.jms.Session
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
/** /**
* Simple AMQP client connecting to broker using JMS. * Simple AMQP client connecting to broker using JMS.
@ -59,12 +59,8 @@ class SimpleAMQPClient(private val target: NetworkHostAndPort, private val confi
private lateinit var connection: Connection private lateinit var connection: Connection
private fun sslContext(): SSLContext { private fun sslContext(): SSLContext {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply { val keyManagerFactory = keyManagerFactory(config.keyStore.get())
init(config.keyStore.get().value.internal, config.keyStore.entryPassword.toCharArray()) val trustManagerFactory = trustManagerFactory(config.trustStore.get())
}
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(config.trustStore.get().value.internal)
}
val sslContext = SSLContext.getInstance("TLS") val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers val trustManagers = trustManagerFactory.trustManagers

View File

@ -4,7 +4,6 @@ import co.paralleluniverse.fibers.instrument.Retransform
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.collect.MutableClassToInstanceMap import com.google.common.collect.MutableClassToInstanceMap
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.zaxxer.hikari.pool.HikariPool import com.zaxxer.hikari.pool.HikariPool
import net.corda.common.logging.errorReporting.NodeDatabaseErrors import net.corda.common.logging.errorReporting.NodeDatabaseErrors
import net.corda.confidential.SwapIdentitiesFlow import net.corda.confidential.SwapIdentitiesFlow
@ -67,6 +66,7 @@ import net.corda.core.toFuture
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.djvm.source.ApiSource import net.corda.djvm.source.ApiSource
import net.corda.djvm.source.EmptyApi import net.corda.djvm.source.EmptyApi
@ -165,6 +165,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager
import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.jolokia.jvmagent.JolokiaServer import org.jolokia.jvmagent.JolokiaServer
import org.jolokia.jvmagent.JolokiaServerConfig import org.jolokia.jvmagent.JolokiaServerConfig
@ -180,9 +181,6 @@ import java.util.ArrayList
import java.util.Properties import java.util.Properties
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS import java.util.concurrent.TimeUnit.SECONDS
import java.util.function.Consumer import java.util.function.Consumer
@ -880,13 +878,12 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
} }
// Start with 1 thread and scale up to the configured thread pool size if needed // Start with 1 thread and scale up to the configured thread pool size if needed
// Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool] // Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool]
return ThreadPoolExecutor( return namedThreadPoolExecutor(
1, corePoolSize = 1,
numberOfThreads, maxPoolSize = numberOfThreads,
0L, idleKeepAlive = 0.millis,
TimeUnit.MILLISECONDS, poolName = "flow-external-operation-thread",
LinkedBlockingQueue<Runnable>(), daemonThreads = true
ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build()
) )
} }

View File

@ -2,6 +2,7 @@ package net.corda.node.internal
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.cordapp.CordappProvider import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.SerializedStateAndRef
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.node.ServicesForResolution import net.corda.core.node.ServicesForResolution
@ -9,8 +10,10 @@ import net.corda.core.node.services.AttachmentStorage
import net.corda.core.node.services.IdentityService import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.NetworkParametersService
import net.corda.core.node.services.TransactionStorage import net.corda.core.node.services.TransactionStorage
import net.corda.core.transactions.BaseTransaction
import net.corda.core.transactions.ContractUpgradeWireTransaction import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent
@ -26,25 +29,23 @@ data class ServicesForResolutionImpl(
@Throws(TransactionResolutionException::class) @Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> { override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) return toBaseTransaction(stateRef.txhash).outputs[stateRef.index]
return stx.resolveBaseTransaction(this).outputs[stateRef.index]
} }
@Throws(TransactionResolutionException::class) @Throws(TransactionResolutionException::class)
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> { override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> {
return stateRefs.groupBy { it.txhash }.flatMap { val baseTxs = HashMap<SecureHash, BaseTransaction>()
val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key) return stateRefs.mapTo(LinkedHashSet()) { stateRef ->
val baseTx = stx.resolveBaseTransaction(this) val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction)
it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) } StateAndRef(baseTx.outputs[stateRef.index], stateRef)
}.toSet() }
} }
@Throws(TransactionResolutionException::class, AttachmentResolutionException::class) @Throws(TransactionResolutionException::class, AttachmentResolutionException::class)
override fun loadContractAttachment(stateRef: StateRef): Attachment { override fun loadContractAttachment(stateRef: StateRef): Attachment {
// We may need to recursively chase transactions if there are notary changes. // We may need to recursively chase transactions if there are notary changes.
fun inner(stateRef: StateRef, forContractClassName: String?): Attachment { fun inner(stateRef: StateRef, forContractClassName: String?): Attachment {
val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction val ctx = getSignedTransaction(stateRef.txhash).coreTransaction
?: throw TransactionResolutionException(stateRef.txhash)
when (ctx) { when (ctx) {
is WireTransaction -> { is WireTransaction -> {
val transactionState = ctx.outRef<ContractState>(stateRef.index).state val transactionState = ctx.outRef<ContractState>(stateRef.index).state
@ -69,4 +70,10 @@ data class ServicesForResolutionImpl(
} }
return inner(stateRef, null) return inner(stateRef, null)
} }
private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this)
private fun getSignedTransaction(txhash: SecureHash): SignedTransaction {
return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash)
}
} }

View File

@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() {
Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE))) Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE)))
} }
ArtemisMessagingComponent.PEER_USER -> { ArtemisMessagingComponent.PEER_USER -> {
requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." }
requireTls(certificates) requireTls(certificates)
// This check is redundant as it was performed already during the SSL handshake // This check is redundant as it was performed already during the SSL handshake
CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) CertificateChainCheckPolicy.RootMustMatch
CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode) .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore)
.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) .checkCertificateChain(certificates)
Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE)))
} }
else -> { else -> {

View File

@ -2,17 +2,9 @@ package net.corda.node.internal.artemis
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString
import java.security.KeyStore import java.security.KeyStore
import java.security.cert.CertPathValidator
import java.security.cert.CertPathValidatorException
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.X509CertSelector
sealed class CertificateChainCheckPolicy { sealed class CertificateChainCheckPolicy {
companion object { companion object {
@ -92,33 +84,4 @@ sealed class CertificateChainCheckPolicy {
} }
} }
} }
class RevocationCheck(val revocationConfig: RevocationConfig) : CertificateChainCheckPolicy() {
constructor(revocationMode: RevocationConfig.Mode) : this(RevocationConfigImpl(revocationMode))
override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check {
return object : Check {
override fun checkCertificateChain(theirChain: Array<java.security.cert.X509Certificate>) {
// Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate.
val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) }
log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}")
// Drop the last certificate which must be a trusted root (validated by RootMustMatch).
// Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain.
// See PKIXValidator.engineValidate() for reference implementation.
val certPath = X509Utilities.buildCertPath(chain.dropLast(1))
val certPathValidator = CertPathValidator.getInstance("PKIX")
val pkixRevocationChecker = revocationConfig.createPKIXRevocationChecker()
val params = PKIXBuilderParameters(trustStore, X509CertSelector())
params.addCertPathChecker(pkixRevocationChecker)
try {
certPathValidator.validate(certPath, params)
} catch (ex: CertPathValidatorException) {
log.error("Bad certificate path", ex)
throw ex
}
}
}
}
}
} }

View File

@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.node.internal.artemis.* import net.corda.node.internal.artemis.ArtemisBroker
import net.corda.node.internal.artemis.BrokerAddresses
import net.corda.node.internal.artemis.BrokerJaasLoginModule
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE
import net.corda.node.internal.artemis.NodeJaasConfig
import net.corda.node.internal.artemis.P2PJaasConfig
import net.corda.node.internal.artemis.SecureArtemisConfiguration
import net.corda.node.internal.artemis.UserValidationPlugin
import net.corda.node.internal.artemis.isBindingError
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.utilities.artemis.startSynchronously import net.corda.node.utilities.artemis.startSynchronously
import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor
@ -21,7 +28,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.nodeapi.internal.requireOnDefaultFileSystem import net.corda.nodeapi.internal.requireOnDefaultFileSystem
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
@ -32,9 +42,7 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import java.io.IOException
import java.lang.Long.max import java.lang.Long.max
import java.security.KeyStoreException
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.login.AppConfigurationEntry import javax.security.auth.login.AppConfigurationEntry
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED
@ -58,7 +66,9 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
private val maxMessageSize: Int, private val maxMessageSize: Int,
private val journalBufferTimeout : Int? = null, private val journalBufferTimeout : Int? = null,
private val threadPoolName: String = "ArtemisServer", private val threadPoolName: String = "ArtemisServer",
private val trace: Boolean = false) : ArtemisBroker, SingletonSerializeAsToken() { private val trace: Boolean = false,
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() {
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
} }
@ -92,7 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
override val started: Boolean override val started: Boolean
get() = activeMQServer.isStarted get() = activeMQServer.isStarted
@Throws(IOException::class, AddressBindingException::class, KeyStoreException::class) @Suppress("ThrowsCount")
private fun configureAndStartServer() { private fun configureAndStartServer() {
val artemisConfig = createArtemisConfig() val artemisConfig = createArtemisConfig()
val securityManager = createArtemisSecurityManager() val securityManager = createArtemisSecurityManager()
@ -132,11 +142,23 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
// The transaction cache is configurable, and drives other cache sizes. // The transaction cache is configurable, and drives other cache sizes.
globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize) globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize)
val revocationMode = if (config.crlCheckArtemisServer) {
if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL
} else {
RevocationConfig.Mode.OFF
}
val trustManagerFactory = trustManagerFactoryWithRevocation(
config.p2pSslOptions.trustStore.get(),
RevocationConfigImpl(revocationMode),
distPointCrlSource
)
addAcceptorConfiguration(p2pAcceptorTcpTransport( addAcceptorConfiguration(p2pAcceptorTcpTransport(
NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port), NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port),
config.p2pSslOptions, config.p2pSslOptions,
trustManagerFactory,
threadPoolName = threadPoolName, threadPoolName = threadPoolName,
trace = trace trace = trace,
remotingThreads = remotingThreads
)) ))
// Enable built in message deduplication. Note we still have to do our own as the delayed commits // Enable built in message deduplication. Note we still have to do our own as the delayed commits
// and our own definition of commit mean that the built in deduplication cannot remove all duplicates. // and our own definition of commit mean that the built in deduplication cannot remove all duplicates.

View File

@ -10,6 +10,7 @@ import io.netty.handler.ssl.SslHandshakeTimeoutException
import net.corda.core.internal.declaredField import net.corda.core.internal.declaredField
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.api.core.BaseInterceptor
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor
import org.apache.activemq.artemis.core.server.balancing.RedirectHandler import org.apache.activemq.artemis.core.server.balancing.RedirectHandler
@ -19,14 +20,18 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor
import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory
import org.apache.activemq.artemis.spi.core.remoting.BufferHandler import org.apache.activemq.artemis.spi.core.remoting.BufferHandler
import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener
import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider
import org.apache.activemq.artemis.utils.ConfigurationHelper import org.apache.activemq.artemis.utils.ConfigurationHelper
import org.apache.activemq.artemis.utils.actors.OrderedExecutor import org.apache.activemq.artemis.utils.actors.OrderedExecutor
import java.net.SocketAddress
import java.nio.channels.ClosedChannelException import java.nio.channels.ClosedChannelException
import java.time.Duration import java.time.Duration
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.regex.Pattern import java.util.regex.Pattern
import javax.net.ssl.SSLEngine import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException
@Suppress("unused") // Used via reflection in ArtemisTcpTransport @Suppress("unused") // Used via reflection in ArtemisTcpTransport
class NodeNettyAcceptorFactory : AcceptorFactory { class NodeNettyAcceptorFactory : AcceptorFactory {
@ -55,9 +60,16 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
{ {
companion object { companion object {
private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""")
init {
// Make sure Artemis isn't using another (Open)SSLContextFactory
check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory)
check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory)
}
} }
private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration)
private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
@Synchronized @Synchronized
@ -71,11 +83,17 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
} }
} }
@Synchronized
override fun stop() {
super.stop()
sslDelegatedTaskExecutor.shutdown()
}
@Synchronized @Synchronized
override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler { override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler {
applyThreadPoolName() applyThreadPoolName()
val engine = super.getSslHandler(alloc, peerHost, peerPort).engine() val engine = super.getSslHandler(alloc, peerHost, peerPort).engine()
val sslHandler = NodeAcceptorSslHandler(engine, trace) val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace)
val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration?
if (handshakeTimeout != null) { if (handshakeTimeout != null) {
sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis() sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis()
@ -95,13 +113,15 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
} }
private class NodeAcceptorSslHandler(engine: SSLEngine, private val trace: Boolean) : SslHandler(engine) { private class NodeAcceptorSslHandler(engine: SSLEngine,
delegatedTaskExecutor: Executor,
private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) {
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
} }
override fun handlerAdded(ctx: ChannelHandlerContext) { override fun handlerAdded(ctx: ChannelHandlerContext) {
logHandshake() logHandshake(ctx.channel().remoteAddress())
super.handlerAdded(ctx) super.handlerAdded(ctx)
// Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way. // Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way.
if (trace) { if (trace) {
@ -109,17 +129,22 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
} }
} }
private fun logHandshake() { private fun logHandshake(remoteAddress: SocketAddress) {
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
handshakeFuture().addListener { handshakeFuture().addListener {
val duration = System.currentTimeMillis() - start val duration = System.currentTimeMillis() - start
val peer = try {
engine().session.peerPrincipal
} catch (e: SSLPeerUnverifiedException) {
remoteAddress
}
when { when {
it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with ${engine().session.peerPrincipal}") it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with $peer")
it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms") it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms with $peer")
else -> when (it.cause()) { else -> when (it.cause()) {
is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms") is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms with $peer")
is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms") is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms with $peer")
else -> logger.warn("SSL handshake failed after ${duration}ms", it.cause()) else -> logger.warn("SSL handshake failed after ${duration}ms with $peer", it.cause())
} }
} }
} }

View File

@ -0,0 +1,59 @@
package net.corda.node.services.messaging
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig
import java.nio.file.Paths
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
class NodeSSLContextFactory : DefaultSSLContextFactory() {
override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SSLContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory)
} else {
super.getSSLContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() {
override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SslContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
SslContextBuilder
.forServer(loadKeyManagerFactory(config))
.sslProvider(SslProvider.OPENSSL)
.trustManager(trustManagerFactory)
.build()
} else {
super.getServerSslContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory {
val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false)
return keyManagerFactory(keyStore)
}

View File

@ -30,7 +30,7 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int,
setDirectories(baseDirectory) setDirectories(baseDirectory)
val acceptorConfigurationsSet = mutableSetOf( val acceptorConfigurationsSet = mutableSetOf(
rpcAcceptorTcpTransport(address, sslOptions, useSsl) rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl)
) )
adminAddress?.let { adminAddress?.let {
acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration) acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration)

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeOpenSSLContextFactory

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeSSLContextFactory

View File

@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
bobNode.internals.disableDBCloseOnStop() bobNode.internals.disableDBCloseOnStop()
bobNode.database.transaction { bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer) VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10)
} }
val alicesFakePaper = aliceNode.database.transaction { val alicesFakePaper = aliceNode.database.transaction {
@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
val issuer = bank.ref(1, 2, 3) val issuer = bank.ref(1, 2, 3)
bobNode.database.transaction { bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer) VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10)
} }
val alicesFakePaper = aliceNode.database.transaction { val alicesFakePaper = aliceNode.database.transaction {
fillUpForSeller(false, issuer, alice, fillUpForSeller(false, issuer, alice,

View File

@ -244,7 +244,6 @@ class FlowSoftLocksTests {
100.DOLLARS, 100.DOLLARS,
bankNode.services, bankNode.services,
thisManyStates, thisManyStates,
thisManyStates,
cashIssuer cashIssuer
) )
} }

View File

@ -20,14 +20,13 @@ import net.corda.finance.*
import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.CommercialPaper
import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.Commodity
import net.corda.finance.contracts.DealState import net.corda.finance.contracts.DealState
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Cash
import net.corda.finance.schemas.CashSchemaV1 import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.schemas.CashSchemaV1.PersistentCashState
import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.schemas.CommercialPaperSchemaV1
import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.test.SampleCashSchemaV3
import net.corda.finance.workflows.CommercialPaperUtils import net.corda.finance.workflows.CommercialPaperUtils
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE) protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
private fun setUpDb(_database: CordaPersistence, delay: Long = 0) {
_database.transaction { private fun setUpDb(database: CordaPersistence, delay: Long = 0) {
database.transaction {
// create new states // create new states
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER)
val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ") val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ")
@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates) Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates)
} }
(1..3).forEach { repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates) Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates)
} }
(1..3).forEach { repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
val sorted = results.states.sortedBy { it.ref.toString() } val sorted = results.states.sortedBy { it.ref.toString() }
assertThat(results.states).isEqualTo(sorted) assertThat(results.states).isEqualTo(sorted)
assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) }
} }
} }
@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")) vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789"))
// count fungible assets // count fungible assets
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) val countCriteria = VaultCustomQueryCriteria(count)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L) assertThat(fungibleStateCount).isEqualTo(10L)
@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
// count fungible assets // count fungible assets
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L) assertThat(fungibleStateCount).isEqualTo(10L)
@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// UNCONSUMED states (default) // UNCONSUMED states (default)
// count fungible assets // count fungible assets
val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val fungibleStateCountUnconsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long val fungibleStateCountUnconsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size)
@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// CONSUMED states // CONSUMED states
// count fungible assets // count fungible assets
val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val fungibleStateCountConsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long val fungibleStateCountConsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size)
@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val start = TODAY val start = TODAY
val end = TODAY.plus(30, ChronoUnit.DAYS) val end = TODAY.plus(30, ChronoUnit.DAYS)
val recordedBetweenExpression = TimeCondition( val recordedBetweenExpression = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED, TimeInstantType.RECORDED,
ColumnPredicate.Between(start, end)) ColumnPredicate.Between(start, end))
val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression) val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression)
val results = vaultService.queryBy<ContractState>(criteria) val results = vaultService.queryBy<ContractState>(criteria)
@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Future // Future
val startFuture = TODAY.plus(1, ChronoUnit.DAYS) val startFuture = TODAY.plus(1, ChronoUnit.DAYS)
val recordedBetweenExpressionFuture = TimeCondition( val recordedBetweenExpressionFuture = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end))
val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture) val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture)
assertThat(vaultService.queryBy<ContractState>(criteriaFuture).states).isEmpty() assertThat(vaultService.queryBy<ContractState>(criteriaFuture).states).isEmpty()
} }
@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
consumeCash(100.DOLLARS) consumeCash(100.DOLLARS)
val asOfDateTime = TODAY val asOfDateTime = TODAY
val consumedAfterExpression = TimeCondition( val consumedAfterExpression = TimeCondition(
QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime))
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED, val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED,
timeCondition = consumedAfterExpression) timeCondition = consumedAfterExpression)
val results = vaultService.queryBy<ContractState>(criteria) val results = vaultService.queryBy<ContractState>(criteria)
@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
// pagination: invalid page size // pagination: invalid page size
@Suppress("INTEGER_OVERFLOW")
@Test(timeout=300_000) @Test(timeout=300_000)
fun `invalid page size`() { fun `invalid page size`() {
expectedEx.expect(VaultQueryException::class.java) expectedEx.expect(VaultQueryException::class.java)
@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
database.transaction { database.transaction {
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER)
@Suppress("EXPECTED_CONDITION") val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648
val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.queryBy<ContractState>(criteria, paging = pagingSpec) vaultService.queryBy<ContractState>(criteria, paging = pagingSpec)
} }
@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
println("$index : $any") println("$index : $any")
} }
assertThat(results.otherResults.size).isEqualTo(402) assertThat(results.otherResults.size).isEqualTo(402)
val instants = results.otherResults.filter { it is Instant }.map { it as Instant } val instants = results.otherResults.filterIsInstance<Instant>()
assertThat(instants).isSorted assertThat(instants).isSorted
val longs = results.otherResults.filter { it is Long }.map { it as Long } val longs = results.otherResults.filterIsInstance<Long>()
assertThat(longs.size).isEqualTo(201) assertThat(longs.size).isEqualTo(201)
assertThat(instants.size).isEqualTo(201) assertThat(instants.size).isEqualTo(201)
assertThat(longs.sum()).isEqualTo(20100L) assertThat(longs.sum()).isEqualTo(20100L)
@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() { fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() {
database.transaction { database.transaction {
val uid = UniqueIdentifier("999") val uid = UniqueIdentifier("999")
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid) vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid)
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234") vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234")
val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id)) val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id))
val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234")) val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234"))
@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
} }
@Test(timeout = 300_000)
fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() {
val linearNumbers = Array(2) { LongArray(2) }
// Make sure states from the same transaction are not given consecutive linear numbers.
linearNumbers[0][0] = 1L
linearNumbers[0][1] = 3L
linearNumbers[1][0] = 2L
linearNumbers[1][1] = 4L
val results = database.transaction {
vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex ->
DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex])
}
val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber"))
vaultService.queryBy<DummyLinearContract.State>(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn)))
}
assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L))
}
@Test(timeout=300_000) @Test(timeout=300_000)
fun `return consumed linear states for a given linear id`() { fun `return consumed linear states for a given linear id`() {
database.transaction { database.transaction {
@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
services.recordTransactions(commercialPaper2) services.recordTransactions(commercialPaper2)
val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) }
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val result = vaultService.queryBy<CommercialPaper.State>(criteria1) val result = vaultService.queryBy<CommercialPaper.State>(criteria1)
@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days) val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days)
val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L) val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L)
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex) val criteria2 = VaultCustomQueryCriteria(maturityIndex)
val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex) val criteria3 = VaultCustomQueryCriteria(faceValueIndex)
vaultService.queryBy<CommercialPaper.State>(criteria1.and(criteria3).and(criteria2)) vaultService.queryBy<CommercialPaper.State>(criteria1.and(criteria3).and(criteria2))
} }
@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL)
val results = builder { val results = builder {
val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode) val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode)
val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L) val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L)
val customCriteria1 = VaultCustomQueryCriteria(currencyIndex) val customCriteria1 = VaultCustomQueryCriteria(currencyIndex)
val customCriteria2 = VaultCustomQueryCriteria(quantityIndex) val customCriteria2 = VaultCustomQueryCriteria(quantityIndex)
@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Enrich and override QueryCriteria with additional default attributes (such as soft locks) // Enrich and override QueryCriteria with additional default attributes (such as soft locks)
val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich
softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())),
status = Vault.StateStatus.UNCONSUMED) // override status = Vault.StateStatus.UNCONSUMED) // override
// Sorting // Sorting
val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF)
@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0) assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush() this.session.flush()
vaultFiller.consumeLinearStates(states.toList()) vaultFiller.consumeStates(states)
updates updates
} }
@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0) assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush() this.session.flush()
vaultFiller.consumeLinearStates(states.toList()) vaultFiller.consumeStates(states)
updates updates
} }
@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0) assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush() this.session.flush()
vaultFiller.consumeLinearStates(states.toList()) vaultFiller.consumeStates(states)
updates updates
} }

View File

@ -1,30 +1,50 @@
@file:Suppress("UNUSED_PARAMETER")
@file:JvmName("TestUtils") @file:JvmName("TestUtils")
@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList")
package net.corda.testing.core package net.corda.testing.core
import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.* import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureScheme
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.toX500Name
import net.corda.core.internal.unspecifiedCountry import net.corda.core.internal.unspecifiedCountry
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.millis import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import net.corda.coretesting.internal.DEV_ROOT_CA import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import java.math.BigInteger import java.math.BigInteger
import java.net.URI
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.util.*
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.fail import kotlin.test.fail
@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party
return getTestPartyAndCertificate(Party(name, publicKey)) return getTestPartyAndCertificate(Party(name, publicKey))
} }
fun createCRL(issuer: CertificateAndKeyPair,
revokedCerts: List<X509Certificate>,
issuingDistPoint: URI? = null,
thisUpdate: Instant = Instant.now(),
nextUpdate: Instant = thisUpdate + 5.minutes,
indirect: Boolean = false,
revocationDate: Instant = thisUpdate,
crlReason: Int = CRLReason.keyCompromise,
signatureAlgorithm: String = "SHA256withECDSA"): X509CRL {
val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate))
// This is required and needs to match the certificate settings with respect to being indirect
builder.addExtension(
Extension.issuingDistributionPoint,
true,
IssuingDistributionPoint(
issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) },
indirect,
false
)
)
builder.setNextUpdate(Date.from(nextUpdate))
for (revokedCert in revokedCerts) {
val extensionsGenerator = ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason))
// Certificate issuer is required for indirect CRL
extensionsGenerator.addExtension(
Extension.certificateIssuer,
true,
GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name()))
)
builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate())
}
val bcProvider = Crypto.findProvider("BC")
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private)
return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer))
}
private val count = AtomicInteger(0) private val count = AtomicInteger(0)
/** /**
@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party
* The above will test our expectation that the getWaitingFlows action was executed successfully considering * The above will test our expectation that the getWaitingFlows action was executed successfully considering
* that it may take a few hundreds of milliseconds for the flow state machine states to settle. * that it may take a few hundreds of milliseconds for the flow state machine states to settle.
*/ */
@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod")
fun <T> executeTest( fun <T> executeTest(
timeout: Duration, timeout: Duration,
cleanup: (() -> Unit)? = null, cleanup: (() -> Unit)? = null,

View File

@ -4,30 +4,26 @@ package net.corda.testing.node.internal.network
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.internal.CertRole import net.corda.core.internal.CertRole
import net.corda.core.internal.toX500Name
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.ContentSignerBuilder import net.corda.nodeapi.internal.crypto.ContentSignerBuilder
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import net.corda.nodeapi.internal.crypto.certificateType import net.corda.nodeapi.internal.crypto.certificateType
import net.corda.nodeapi.internal.crypto.toJca import net.corda.nodeapi.internal.crypto.toJca
import org.bouncycastle.asn1.x500.X500Name import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x509.CRLDistPoint import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.asn1.x509.ReasonFlags
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.ServerConnector import org.eclipse.jetty.server.ServerConnector
import org.eclipse.jetty.server.handler.HandlerCollection import org.eclipse.jetty.server.handler.HandlerCollection
@ -36,11 +32,12 @@ import org.eclipse.jetty.servlet.ServletHolder
import org.glassfish.jersey.server.ResourceConfig import org.glassfish.jersey.server.ResourceConfig
import org.glassfish.jersey.servlet.ServletContainer import org.glassfish.jersey.servlet.ServletContainer
import java.io.Closeable import java.io.Closeable
import java.math.BigInteger
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.net.URI
import java.security.KeyPair import java.security.KeyPair
import java.security.cert.X509CRL import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration
import java.util.* import java.util.*
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import javax.ws.rs.GET import javax.ws.rs.GET
@ -51,7 +48,7 @@ import kotlin.collections.ArrayList
class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
companion object { companion object {
private const val SIGNATURE_ALGORITHM = "SHA256withECDSA" private val logger = contextLogger()
const val NODE_CRL = "node.crl" const val NODE_CRL = "node.crl"
const val FORBIDDEN_CRL = "forbidden.crl" const val FORBIDDEN_CRL = "forbidden.crl"
@ -72,8 +69,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
null null
) )
if (crlDistPoint != null) { if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(X500Name.getInstance(it.encoded))) } val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) }
val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
} }
@ -87,14 +84,17 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
} }
} }
val revokedNodeCerts: MutableList<BigInteger> = ArrayList() val revokedNodeCerts: MutableList<X509Certificate> = ArrayList()
val revokedIntermediateCerts: MutableList<BigInteger> = ArrayList() val revokedIntermediateCerts: MutableList<X509Certificate> = ArrayList()
val rootCa: CertificateAndKeyPair = DEV_ROOT_CA val rootCa: CertificateAndKeyPair = DEV_ROOT_CA
private lateinit var _intermediateCa: CertificateAndKeyPair private lateinit var _intermediateCa: CertificateAndKeyPair
val intermediateCa: CertificateAndKeyPair get() = _intermediateCa val intermediateCa: CertificateAndKeyPair get() = _intermediateCa
@Volatile
var delay: Duration? = null
val hostAndPort: NetworkHostAndPort val hostAndPort: NetworkHostAndPort
get() = server.connectors.mapNotNull { it as? ServerConnector } get() = server.connectors.mapNotNull { it as? ServerConnector }
.map { NetworkHostAndPort(it.host, it.localPort) } .map { NetworkHostAndPort(it.host, it.localPort) }
@ -106,7 +106,7 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"), DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"),
DEV_INTERMEDIATE_CA.keyPair DEV_INTERMEDIATE_CA.keyPair
) )
println("Network management web services started on $hostAndPort") logger.info("Network management web services started on $hostAndPort")
} }
fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate, fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate,
@ -115,29 +115,20 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer) return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer)
} }
fun createRevocationList(signatureAlgorithm: String, private fun createServerCRL(issuer: CertificateAndKeyPair,
ca: CertificateAndKeyPair, endpoint: String,
endpoint: String, indirect: Boolean,
indirect: Boolean, revokedCerts: List<X509Certificate>): X509CRL {
serialNumbers: List<BigInteger>): X509CRL { logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}")
println("Generating CRL for $endpoint") return createCRL(
val builder = JcaX509v2CRLBuilder(ca.certificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis())) issuer,
val extensionUtils = JcaX509ExtensionUtils() revokedCerts,
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(ca.certificate)) issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"),
val issuingDistPointName = GeneralName(GeneralName.uniformResourceIdentifier, "http://$hostAndPort/crl/$endpoint") indirect = indirect
// This is required and needs to match the certificate settings with respect to being indirect )
val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
serialNumbers.forEach {
builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
}
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(ca.keyPair.private)
return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer))
} }
override fun close() { override fun close() {
println("Shutting down network management web services...")
server.stop() server.stop()
server.join() server.join()
} }
@ -159,8 +150,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(NODE_CRL) @Path(NODE_CRL)
@Produces("application/pkcs7-crl") @Produces("application/pkcs7-crl")
fun getNodeCRL(): Response { fun getNodeCRL(): Response {
return Response.ok(crlServer.createRevocationList( crlServer.delay?.toMillis()?.let(Thread::sleep)
SIGNATURE_ALGORITHM, return Response.ok(crlServer.createServerCRL(
crlServer.intermediateCa, crlServer.intermediateCa,
NODE_CRL, NODE_CRL,
false, false,
@ -179,8 +170,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(INTERMEDIATE_CRL) @Path(INTERMEDIATE_CRL)
@Produces("application/pkcs7-crl") @Produces("application/pkcs7-crl")
fun getIntermediateCRL(): Response { fun getIntermediateCRL(): Response {
return Response.ok(crlServer.createRevocationList( crlServer.delay?.toMillis()?.let(Thread::sleep)
SIGNATURE_ALGORITHM, return Response.ok(crlServer.createServerCRL(
crlServer.rootCa, crlServer.rootCa,
INTERMEDIATE_CRL, INTERMEDIATE_CRL,
false, false,
@ -192,11 +183,11 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(EMPTY_CRL) @Path(EMPTY_CRL)
@Produces("application/pkcs7-crl") @Produces("application/pkcs7-crl")
fun getEmptyCRL(): Response { fun getEmptyCRL(): Response {
return Response.ok(crlServer.createRevocationList( return Response.ok(crlServer.createServerCRL(
SIGNATURE_ALGORITHM,
crlServer.rootCa, crlServer.rootCa,
EMPTY_CRL, EMPTY_CRL,
true, emptyList() true,
emptyList()
).encoded).build() ).encoded).build()
} }
} }

View File

@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.serialization.internal.amqp.AMQP_ENABLED import net.corda.serialization.internal.amqp.AMQP_ENABLED
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
@ -52,6 +53,8 @@ import java.io.IOException
import java.net.ServerSocket import java.net.ServerSocket
import java.nio.file.Path import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.* import java.util.*
import java.util.jar.JarOutputStream import java.util.jar.JarOutputStream
import java.util.jar.Manifest import java.util.jar.Manifest
@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L
return sslConfig return sslConfig
} }
fun fixedCrlSource(crls: Set<X509CRL>): CrlSource {
return object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = crls
}
}
/** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */ /** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */
@SuppressWarnings("LongParameterList") @SuppressWarnings("LongParameterList")
fun createWireTransaction(inputs: List<StateRef>, fun createWireTransaction(inputs: List<StateRef>,

View File

@ -1,6 +1,20 @@
@file:Suppress("LongParameterList")
package net.corda.testing.internal.vault package net.corda.testing.internal.vault
import net.corda.core.contracts.* import net.corda.core.contracts.Amount
import net.corda.core.contracts.AttachmentConstraint
import net.corda.core.contracts.AutomaticPlaceholderConstraint
import net.corda.core.contracts.BelongsToContract
import net.corda.core.contracts.CommandAndState
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.Issued
import net.corda.core.contracts.LinearState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.SignatureMetadata
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash
import net.corda.finance.contracts.asset.Obligation import net.corda.finance.contracts.asset.Obligation
import net.corda.finance.contracts.asset.OnLedgerAsset import net.corda.finance.contracts.asset.OnLedgerAsset
import net.corda.finance.workflows.asset.CashUtils import net.corda.finance.workflows.asset.CashUtils
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState import net.corda.testing.contracts.DummyState
import net.corda.testing.core.DummyCommandData
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.core.dummyCommand import net.corda.testing.core.dummyCommand
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
@ -32,6 +44,7 @@ import java.time.Duration
import java.time.Instant import java.time.Instant
import java.time.Instant.now import java.time.Instant.now
import java.util.* import java.util.*
import kotlin.math.floor
/** /**
* The service hub should provide at least a key management service and a storage service. * The service hub should provide at least a key management service and a storage service.
@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor(
private val rngFactory: () -> Random = { Random(0L) }) { private val rngFactory: () -> Random = { Random(0L) }) {
companion object { companion object {
fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray { fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray {
val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt() val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt()
val baseSize = howMuch.quantity / numSlots val baseSize = howMuch.quantity / numSlots
check(baseSize > 0) { baseSize } check(baseSize > 0) { baseSize }
@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor(
issuerServices: ServiceHub = services, issuerServices: ServiceHub = services,
participants: List<AbstractParty> = emptyList(), participants: List<AbstractParty> = emptyList(),
includeMe: Boolean = true): Vault<DealState> { includeMe: Boolean = true): Vault<DealState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey return fillWithTestStates(
val me = AnonymousParty(myKey) txCount = dealIds.size,
val participantsToUse = if (includeMe) participants.plus(me) else participants participants = participants,
includeMe = includeMe,
val transactions: List<SignedTransaction> = dealIds.map { services = issuerServices
// Issue a deal state ) { participantsToUse, txIndex, _ ->
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse)
addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID)
addCommand(dummyCommand())
}
val stx = issuerServices.signInitialTransaction(dummyIssue)
return@map services.addSignature(stx, defaultNotary.publicKey)
} }
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<DealState>(i) }
}
return Vault(states)
} }
@JvmOverloads @JvmOverloads
fun fillWithSomeTestLinearStates(numberToCreate: Int, fun fillWithSomeTestLinearStates(txCount: Int,
externalId: String? = null, externalId: String? = null,
participants: List<AbstractParty> = emptyList(), participants: List<AbstractParty> = emptyList(),
uniqueIdentifier: UniqueIdentifier? = null, uniqueIdentifier: UniqueIdentifier? = null,
@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor(
linearTimestamp: Instant = now(), linearTimestamp: Instant = now(),
constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
includeMe: Boolean = true): Vault<LinearState> { includeMe: Boolean = true): Vault<LinearState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ ->
val me = AnonymousParty(myKey) DummyLinearContract.State(
val issuerKey = defaultNotary.keyPair linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) participants = participantsToUse,
val participantsToUse = if (includeMe) participants.plus(me) else participants linearString = linearString,
val transactions: List<SignedTransaction> = (1..numberToCreate).map { linearNumber = linearNumber,
// Issue a Linear state linearBoolean = linearBoolean,
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { linearTimestamp = linearTimestamp
addOutputState(DummyLinearContract.State( )
linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID,
constraint = constraint)
addCommand(dummyCommand())
}
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
} }
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
} }
@JvmOverloads @JvmOverloads
fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, fun fillWithSomeTestLinearAndDealStates(txCount: Int,
externalId: String? = null, externalId: String? = null,
participants: List<AbstractParty> = emptyList(), participants: List<AbstractParty> = emptyList(),
linearString: String = "", linearString: String = "",
linearNumber: Long = 0L, linearNumber: Long = 0L,
linearBoolean: Boolean = false, linearBoolean: Boolean = false,
linearTimestamp: Instant = now()): Vault<LinearState> { linearTimestamp: Instant = now()): Vault<ContractState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex ->
val me = AnonymousParty(myKey) when (stateIndex) {
val issuerKey = defaultNotary.keyPair 0 -> DummyLinearContract.State(
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID)
val transactions: List<SignedTransaction> = (1..numberToCreate).map {
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
// Issue a Linear state
addOutputState(DummyLinearContract.State(
linearId = UniqueIdentifier(externalId), linearId = UniqueIdentifier(externalId),
participants = participants.plus(me), participants = participantsToUse,
linearString = linearString, linearString = linearString,
linearNumber = linearNumber, linearNumber = linearNumber,
linearBoolean = linearBoolean, linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) linearTimestamp = linearTimestamp
// Issue a Deal state )
addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse)
addCommand(dummyCommand())
} }
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
} }
services.recordTransactions(transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
} }
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub,
thisManyStates: Int,
issuedBy: PartyAndReference,
owner: AbstractParty? = null,
rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord)
/** /**
* Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them * Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them
* to the vault. This is intended for unit tests. By default the cash is owned by the legal * to the vault. This is intended for unit tests. By default the cash is owned by the legal
@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor(
* @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @param issuerServices service hub of the issuer node, which will be used to sign the transaction.
* @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!).
*/ */
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>, fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub, issuerServices: ServiceHub,
atLeastThisManyStates: Int, atLeastThisManyStates: Int,
atMostThisManyStates: Int,
issuedBy: PartyAndReference, issuedBy: PartyAndReference,
owner: AbstractParty? = null, owner: AbstractParty? = null,
rng: Random? = null, rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<Cash.State> { statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT,
atMostThisManyStates: Int = atLeastThisManyStates): Vault<Cash.State> {
val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory()) val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory())
// We will allocate one state to one transaction, for simplicities sake. // We will allocate one state to one transaction, for simplicities sake.
val cash = Cash() val cash = Cash()
@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor(
cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary) cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary)
return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
} }
services.recordTransactions(statesToRecord, transactions) return recordTransactions(transactions, statesToRecord)
// Get all the StateRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<Cash.State>(i) }
}
return Vault(states)
} }
/** /**
* Records a dummy state in the Vault (useful for creating random states when testing vault queries) * Records a dummy state in the Vault (useful for creating random states when testing vault queries)
*/ */
fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())) : Vault<DummyState> { fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())): Vault<DummyState> {
val outputState = TransactionState( return fillWithTestStates(participants = participants) { participantsToUse, _, _ ->
data = DummyState(Random().nextInt(), participants = participants), DummyState(Random().nextInt(), participants = participantsToUse)
contract = DummyContract.PROGRAM_ID, }
notary = defaultNotary.party
)
val participantKeys : List<PublicKey> = participants.map { it.owningKey }
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, participantKeys)
val stxn = services.signInitialTransaction(builder)
services.recordTransactions(stxn)
return Vault(setOf(stxn.tx.outRef(0)))
} }
/** fun <T : ContractState> fillWithTestStates(txCount: Int = 1,
* Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. statesPerTx: Int = 1,
*/ participants: List<AbstractParty> = emptyList(),
fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount<Issued<Commodity>>, owner: AbstractParty, notary: Party) constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
= OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue()) includeMe: Boolean = true,
services: ServiceHub = this.services,
genOutputState: (participantsToUse: List<AbstractParty>, txIndex: Int, stateIndex: Int) -> T): Vault<T> {
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(
services.myInfo.platformVersion,
Crypto.findSignatureScheme(issuerKey.public).schemeNumberID
)
val participantsToUse = if (includeMe) {
participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey)
} else {
participants
}
val transactions = Array(txCount) { txIndex ->
val builder = TransactionBuilder(notary = defaultNotary.party)
repeat(statesPerTx) { stateIndex ->
builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint)
}
builder.addCommand(dummyCommand())
services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
return recordTransactions(transactions.asList(), statesToRecord)
}
/** /**
* *
@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor(
val me = AnonymousParty(myKey) val me = AnonymousParty(myKey)
val issuance = TransactionBuilder(null as Party?) val issuance = TransactionBuilder(null as Party?)
generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary) OnLedgerAsset.generateIssue(
issuance,
TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary),
Obligation.Commands.Issue()
)
val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
services.recordTransactions(transaction) return recordTransactions(listOf(transaction))
return Vault(setOf(transaction.tx.outRef(0)))
} }
private fun <T : LinearState> consume(states: List<StateAndRef<T>>) { fun consumeStates(states: Iterable<StateAndRef<*>>) {
// Create a txn consuming different contract types // Create a txn consuming different contract types
states.forEach { states.forEach {
val builder = TransactionBuilder(notary = altNotary).apply { val builder = TransactionBuilder(notary = altNotary).apply {
@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor(
} }
} }
fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consume(dealStates) fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consumeStates(dealStates)
fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consume(linearStates) fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeStates(linearStates)
fun evolveLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeAndProduce(linearStates) fun evolveLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeAndProduce(linearStates)
fun evolveLinearState(linearState: StateAndRef<LinearState>): StateAndRef<LinearState> = consumeAndProduce(linearState) fun evolveLinearState(linearState: StateAndRef<LinearState>): StateAndRef<LinearState> = consumeAndProduce(linearState)
/** /**
* Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios, * Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios,
* where nodes have a default identity. * where nodes have a default identity.
@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor(
services.recordTransactions(spendTx) services.recordTransactions(spendTx)
return update.getOrThrow(Duration.ofSeconds(3)) return update.getOrThrow(Duration.ofSeconds(3))
} }
private fun <T : ContractState> recordTransactions(transactions: Iterable<SignedTransaction>,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<T> {
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<T>(i) }
}
return Vault(states)
}
} }
@ -344,4 +326,3 @@ data class CommodityState(
override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner)) override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner))
} }