Merge pull request #7381 from corda/shams-4.9-frwd-merge-a817218b

ENT-9806: 4.8 to 4.9 forward merge
This commit is contained in:
Rick Parker 2023-06-05 09:33:24 +01:00 committed by GitHub
commit d0f28a607f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 1246 additions and 951 deletions

View File

@ -0,0 +1,29 @@
package net.corda.coretests.crypto.internal
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.testing.core.createCRL
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.Test
class ProviderMapTest {
// https://github.com/corda/corda/pull/3997
@Test(timeout = 300_000)
fun `verify CRL algorithms`() {
val crl = createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "SHA256withECDSA"
)
// This should pass.
crl.verify(DEV_ROOT_CA.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "EC"
)
}.withMessage("Unknown signature type requested: EC")
}
}

View File

@ -64,7 +64,7 @@ interface ServicesForResolution {
/**
* Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction.
* @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction.
*/
// TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load
// as the existing transaction store will become encrypted at some point

View File

@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.days
import net.corda.core.utilities.hours
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.nodeapi.internal.installDevNodeCaCertPath
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.SerializationFactoryImpl
@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.testing.internal.IS_OPENJ9
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.junit.Assume
@ -74,10 +80,19 @@ import java.security.PrivateKey
import java.security.cert.CertPath
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.*
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import javax.security.auth.x500.X500Principal
import kotlin.concurrent.thread
import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
class X509UtilitiesTest {
private companion object {
@ -295,15 +310,10 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom())
@ -388,15 +398,8 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get())
val sslServerContext = SslContextBuilder
.forServer(keyManagerFactory)

View File

@ -1,16 +1,18 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path
import javax.net.ssl.TrustManagerFactory
@Suppress("LongParameterList")
class ArtemisTcpTransport {
@ -23,6 +25,7 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2")
const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
const val TRACE_NAME = "Corda-Trace"
const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
@ -30,7 +33,6 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats.
private const val P2P_PROTOCOLS = "CORE,AMQP"
private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
@ -39,46 +41,35 @@ class ArtemisTcpTransport {
TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false)
private val defaultSSLOptions = mapOf(
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(","))
private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
if (keyStore != null || trustStore != null) {
options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
}
keyStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path))
options[TransportConstants.KEYSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
}
}
trustStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path))
options[TransportConstants.TRUSTSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
}
}
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
}
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider,
@ -94,50 +85,64 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory),
enableSSL: Boolean = true,
threadPoolName: String = "P2PServer",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createAcceptorTransport(
hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
}
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
enableSSL: Boolean = true,
threadPoolName: String = "P2PClient",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
}
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: BrokerRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, "RPCServer", trace, remotingThreads)
}
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: ClientRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
}
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
@ -145,25 +150,42 @@ class ArtemisTcpTransport {
trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace, null)
}
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace)
return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
"Internal-RPCServer",
trace,
remotingThreads
)
}
private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
protocols: String,
options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort,
@ -171,7 +193,8 @@ class ArtemisTcpTransport {
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -180,7 +203,12 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
if (enableSSL) {
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
}
return createTransport(
NodeNettyConnectorFactory::class.java.name,
hostAndPort,
@ -188,7 +216,8 @@ class ArtemisTcpTransport {
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -198,13 +227,15 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) {
options += defaultSSLOptions
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
}
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace
return TransportConfiguration(className, options)

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -100,7 +100,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration,
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
@ -116,7 +116,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() })
MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block()
} finally {
@ -134,7 +134,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
val amqpClient = AMQPClient(targets, allowedRemoteLegalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
private var session: ClientSession? = null
private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null
@ -231,7 +231,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
ArtemisState.STOPPING
}
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe()
connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -243,7 +243,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) {
logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames)
bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -286,7 +286,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
}
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -418,10 +418,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value
}
}
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(),
allowedRemoteLegalNames.first().toString(),
properties)
sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -486,7 +486,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName)
}
bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames)
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
}
}
}
@ -498,7 +498,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false)
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap()
}
}

View File

@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.*
import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.*
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -32,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair
import java.security.PublicKey
import java.security.SignatureException
import java.security.cert.*
import java.security.cert.CertPath
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.time.temporal.ChronoUnit
@ -359,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint)))
val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer))
}
@ -379,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes)
}
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
}
// Assuming cert type to role is 1:1

View File

@ -27,15 +27,14 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
enum class ProxyVersion {
@ -63,7 +62,8 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null,
private val threadPoolName: String = "AMQPClient") : AutoCloseable {
private val threadPoolName: String = "AMQPClient",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -89,12 +89,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private val badCertTargets = mutableSetOf<NetworkHostAndPort>()
@Volatile
private var amqpActive = false
@Volatile
private var amqpChannelHandler: ChannelHandler? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
val localAddressString: String
get() = clientChannel?.localAddress()?.toString() ?: "<unknownLocalAddress>"
@ -150,17 +150,16 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
@Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
@Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
@ -199,10 +198,24 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
delegatedTaskExecutor
)
} else {
createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory)
createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
delegatedTaskExecutor
)
}
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler)
@ -260,6 +273,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
return
}
log.info("Connect to: $currentTarget")
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
started = true
restart()
@ -294,6 +308,8 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
clientChannel = null
workerGroup = null
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
log.info("Stopped connection to $currentTarget")
}
}
@ -334,6 +350,4 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}
}

View File

@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery
import rx.Observable
import rx.subjects.PublishSubject
import java.net.BindException
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
/**
@ -39,37 +38,34 @@ import kotlin.concurrent.withLock
class AMQPServer(val hostName: String,
val port: Int,
private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer") : AutoCloseable {
private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4)
private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
}
private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline()
@ -116,11 +112,12 @@ class AMQPServer(val hostName: String,
Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc())
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else {
// For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory)
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
}
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory))
@ -132,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock {
stop()
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -154,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() {
lock.withLock {
try {
stopping = true
serverChannel?.apply { close() }
serverChannel = null
serverChannel?.close()
serverChannel = null
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup = null
workerGroup = null
bossGroup = null
} finally {
stopping = false
}
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
}
}
@ -226,6 +225,4 @@ class AMQPServer(val hostName: String,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}

View File

@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return Collections.emptyList()
}
override fun clone(): AllowAllRevocationChecker = this
}

View File

@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import com.typesafe.config.Config
import net.corda.nodeapi.internal.config.ConfigParser
import net.corda.nodeapi.internal.config.CustomConfigParser
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import java.security.cert.PKIXRevocationChecker
/**
* Data structure for controlling the way how Certificate Revocation Lists are handled.
@ -45,18 +42,6 @@ interface RevocationConfig {
* Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/
val externalCrlSource: CrlSource?
fun createPKIXRevocationChecker(): PKIXRevocationChecker {
return when (mode) {
Mode.OFF -> AllowAllRevocationChecker
Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" }
CordaRevocationChecker(externalCrlSource, softFail = true)
}
Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true)
Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false)
}
}
}
/**

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.DERIA5String
@ -34,10 +38,10 @@ import java.net.URI
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate
import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal
import kotlin.system.measureTimeMillis
private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default"
@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton
/**
* Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/
@Suppress("ComplexMethod")
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" }
@ -117,6 +119,14 @@ fun certPathToString(certPath: Array<out X509Certificate>?): String {
return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
}
/**
* Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
* which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
*/
fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
}
@VisibleForTesting
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object {
@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
}
private object LoggingImmediateExecutor : Executor {
override fun execute(command: Runnable) {
val log = LoggerFactory.getLogger(javaClass)
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true
@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false
@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext {
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java)
.map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, newSecureRandom())
val trustManagers = trustManagerFactory
?.trustManagers
?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationConfig: RevocationConfig): CertPathTrustManagerParameters {
return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker())
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
}
/**
* Creates a special SNI handler used only when openSSL is used for AMQPServer
*/
@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build())
}
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray())
.protocols(ArtemisTcpTransport.TLS_VERSIONS)
}
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
@ -327,7 +307,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
// 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal)
fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
return keyManagerFactory
}
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache
import net.corda.core.internal.readFully
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints
import java.net.URI
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import java.time.Duration
import javax.security.auth.x500.X500Principal
/**
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate.
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them.
*/
@Suppress("TooGenericExceptionCaught")
class CertDistPointCrlSource : CrlSource {
class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE,
cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY,
private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT,
private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource {
companion object {
private val logger = contextLogger()
// The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a
// node handshake, we want to keep the total timeout within that.
private const val DEFAULT_CONNECT_TIMEOUT = 9_000
private const val DEFAULT_READ_TIMEOUT = 9_000
private val DEFAULT_CONNECT_TIMEOUT = 9.seconds
private val DEFAULT_READ_TIMEOUT = 9.seconds
private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore)
private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L
private val DEFAULT_CACHE_EXPIRY = 5.minutes
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE))
.expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS)
.build(::retrieveCRL)
val SINGLETON = CertDistPointCrlSource(
cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE),
cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY,
connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT,
readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT
)
}
private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT)
private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT)
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(cacheSize)
.expireAfterWrite(cacheExpiry)
.build(::retrieveCRL)
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout
conn.readTimeout = readTimeout
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
throw e
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout.toMillis().toInt()
conn.readTimeout = readTimeout.toMillis().toInt()
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
throw e
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
}
fun clearCache() {
cache.invalidateAll()
}
override fun fetch(certificate: X509Certificate): Set<X509CRL> {

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Rule
import org.junit.Test
@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Ignore
import org.junit.Rule
@ -18,7 +19,6 @@ import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import javax.net.ssl.*
import javax.net.ssl.SNIHostName
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertFalse
@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.util.concurrent.ImmediateExecutor
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.testing.internal.fixedCrlSource
import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals
class SSLHelperTest {
@ -20,15 +20,21 @@ class SSLHelperTest {
val legalName = CordaX500Name("Test", "London", "GB")
val sslConfig = configureTestSSL(legalName)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyStore = sslConfig.keyStore
keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false))
val trustStore = sslConfig.trustStore
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)))
val trustManagerFactory = trustManagerFactoryWithRevocation(
sslConfig.trustStore.get(),
RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL),
fixedCrlSource(emptySet())
)
val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory)
val sslHandler = createClientSslHandler(
NetworkHostAndPort("localhost", 1234),
setOf(legalName),
keyManagerFactory,
trustManagerFactory,
ImmediateExecutor.INSTANCE
)
val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
// These hardcoded values must not be changed, something is broken if you have to change these hardcoded values.

View File

@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.testing.core.ALICE_NAME
import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.node.internal.network.CrlServer
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.math.BigInteger
class CertDistPointCrlSourceTest {
private lateinit var crlServer: CrlServer
@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest {
assertThat(single().revokedCertificates).isNull()
}
val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate)
crlSource.clearCache()
crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN)
with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache
crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate
with(crlSource.fetch(crlServer.intermediateCa.certificate)) {
assertThat(size).isEqualTo(1)
val revokedCertificates = single().revokedCertificates
assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN)
// This also tests clearCache() works.
assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber)
}
}
}

View File

@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.testing.internal.fixedCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test
import java.math.BigInteger
@ -41,10 +41,8 @@ class CordaRevocationCheckerTest {
val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL
val crlSource = object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl)
}
val checker = CordaRevocationChecker(crlSource,
val checker = CordaRevocationChecker(
crlSource = fixedCrlSource(setOf(crl)),
softFail = true,
dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) }
)

View File

@ -1,20 +1,16 @@
package net.corda.node.internal.artemis
package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.core.utilities.days
import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509KeyStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.junit.Before
import org.junit.Rule
import org.junit.Test
@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.io.File
import java.security.KeyPair
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.X509TrustManager
import javax.security.auth.x500.X500Principal
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
@RunWith(Parameterized::class)
class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
class RevocationTest(private val revocationMode: RevocationConfig.Mode) {
companion object {
@JvmStatic
@Parameterized.Parameters(name = "revocationMode = {0}")
@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var doormanCRL: File
private lateinit var tlsCRL: File
private val keyStore = KeyStore.getInstance("JKS")
private val trustStore = KeyStore.getInstance("JKS")
private lateinit var trustManager: X509TrustManager
private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
@ -61,7 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var tlsCert: X509Certificate
private val chain
get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).toTypedArray()
get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert)
@Before
fun before() {
@ -72,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair)
tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair)
val trustStore = KeyStore.getInstance("JKS")
trustStore.load(null, null)
trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert)
trustStore.setCertificateEntry("cordarootca", rootCert)
val trustManagerFactory = trustManagerFactoryWithRevocation(
CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"),
RevocationConfigImpl(revocationMode),
CertDistPointCrlSource()
)
trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager
doormanCert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public,
crlDistPoint = rootCRL.toURI().toString()
@ -89,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
rootCRL.createCRL(rootCert, rootKeyPair.private, false)
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
rootCRL.writeCRL(rootCert, rootKeyPair.private, false)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
}
private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date())
builder.setNextUpdate(Date.from(Date().toInstant() + 7.days))
builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false))
revoked.forEach {
val extensionsGenerator = ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise))
// Certificate issuer is required for indirect CRL
val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded)
extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName)))
builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate())
}
val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey))
outputStream().use { it.write(holder.encoded) }
private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val crl = createCRL(
CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)),
revoked.asList(),
indirect = indirect
)
writeBytes(crl.encoded)
}
private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) {
if (revocationMode in modes) assertFails(block) else block()
private fun assertFailsFor(vararg modes: RevocationConfig.Mode) {
if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck()
}
@Test(timeout = 300_000)
fun `ok with empty CRLs`() {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked TLS certificate`() {
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -136,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -148,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown")
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -160,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString()
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -172,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked node CA certificate`() {
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -193,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman"
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -205,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString()
)
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
doRevocationCheck()
}
private fun doRevocationCheck() {
trustManager.checkClientTrusted(chain, "ECDHE_ECDSA")
}
}

View File

@ -269,8 +269,6 @@ tasks.register('integrationTest', Test) {
testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath
maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger()
// CertificateRevocationListNodeTests
systemProperty 'net.corda.dpcrl.connect.timeout', '4000'
}
tasks.register('slowIntegrationTest', Test) {

View File

@ -14,12 +14,15 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.fixedCrlSource
import org.junit.Assume.assumeFalse
import org.junit.Before
import org.junit.Rule
@ -96,11 +99,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val maxMessageSize: Int = MAX_MESSAGE_SIZE
}
serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
serverKeyManagerFactory = keyManagerFactory(keyStore)
serverKeyManagerFactory.init(keyStore)
serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig))
serverTrustManagerFactory = trustManagerFactoryWithRevocation(
serverAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
}
private fun setupClientCertificates() {
@ -127,11 +132,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val sslHandshakeTimeout: Duration = 3.seconds
}
clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
clientKeyManagerFactory = keyManagerFactory(keyStore)
clientKeyManagerFactory.init(keyStore)
clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig))
clientTrustManagerFactory = trustManagerFactoryWithRevocation(
clientAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
}
@Test(timeout = 300_000)

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.node.amqp
import com.nhaarman.mockito_kotlin.doReturn
@ -5,10 +7,10 @@ import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.Crypto
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.internal.rootCause
import net.corda.core.internal.times
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration
@ -18,64 +20,68 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.config.CertificateStoreSupplier
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.node.internal.network.CrlServer
import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.FORBIDDEN_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint
import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.net.SocketTimeoutException
import java.io.Closeable
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals
import java.util.stream.IntStream
@Suppress("LongParameterList")
class CertificateRevocationListNodeTests {
abstract class AbstractServerRevocationTest {
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
private val portAllocation = incrementalPortAllocation()
private val serverPort = portAllocation.nextPort()
protected val serverPort = portAllocation.nextPort()
private lateinit var crlServer: CrlServer
private lateinit var amqpServer: AMQPServer
private lateinit var amqpClient: AMQPClient
protected lateinit var crlServer: CrlServer
private val amqpClients = ArrayList<AMQPClient>()
private abstract class AbstractNodeConfiguration : NodeConfiguration
protected lateinit var defaultCrlDistPoints: CrlDistPoints
protected abstract class AbstractNodeConfiguration : NodeConfiguration
companion object {
private val unreachableIpCounter = AtomicInteger(1)
private val crlConnectTimeout = Duration.ofMillis(System.getProperty("net.corda.dpcrl.connect.timeout").toLong())
val crlConnectTimeout = 2.seconds
/**
* Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes
* may not work as the OS process may cache the timeout result.
*/
private fun newUnreachableIpAddress(): String {
private fun newUnreachableIpAddress(): NetworkHostAndPort {
check(unreachableIpCounter.get() != 255)
return "10.255.255.${unreachableIpCounter.getAndIncrement()}"
return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement())
}
}
@ -85,252 +91,190 @@ class CertificateRevocationListNodeTests {
Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME)
crlServer = CrlServer(NetworkHostAndPort("localhost", 0))
crlServer.start()
defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort)
}
@After
fun tearDown() {
if (::amqpClient.isInitialized) {
amqpClient.close()
}
if (::amqpServer.isInitialized) {
amqpServer.close()
}
amqpClients.parallelStream().forEach(AMQPClient::close)
if (::crlServer.isInitialized) {
crlServer.close()
}
}
@Test(timeout=300_000)
fun `AMQP server connection works and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
expectedConnectStatus = true
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection works and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection succeeds when soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
expectedConnectStatus = true
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection fails when client's certificate is revoked and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
revokeClientCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when client's certificate is revoked and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
revokeClientCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection fails when server's certificate is revoked and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
revokeServerCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when server's certificate is revoked and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
revokeServerCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl",
expectedConnectStatus = true
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when CRL cannot be obtained and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when CRL cannot be obtained and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl",
expectedConnectStatus = false
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL is not defined and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = null,
expectedConnectStatus = true
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when CRL is not defined and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
nodeCrlDistPoint = null,
expectedConnectStatus = false
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL",
expectedConnectStatus = true
clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() {
verifyAMQPConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout * 3,
expectedConnectStatus = true
fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
expectedConnectedStatus = false
)
val timeoutExceptions = (amqpServer.softFailExceptions + amqpClient.softFailExceptions)
.map { it.rootCause }
.filterIsInstance<SocketTimeoutException>()
assertThat(timeoutExceptions).isNotEmpty
}
@Test(timeout=300_000)
fun `AMQP server connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() {
verifyAMQPConnection(
fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() {
verifyConnection(
crlCheckSoftFail = true,
sslHandshakeTimeout = crlConnectTimeout * 4,
clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2,
expectedConnectStatus = false
clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `verify CRL algorithms`() {
val crl = crlServer.createRevocationList(
"SHA256withECDSA",
crlServer.rootCa,
EMPTY_CRL,
true,
emptyList()
@Test(timeout = 300_000)
fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() {
val serverCrlSource = CertDistPointCrlSource()
// Start the server and verify the first client has connected
val firstClientConnectionChangeStatus = verifyConnection(
crlCheckSoftFail = true,
crlSource = serverCrlSource,
// In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with
// existing clients. The trick is to make sure at least N new clients are also supported.
remotingThreads = 2,
expectedConnectedStatus = true
)
// This should pass.
crl.verify(crlServer.rootCa.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
crlServer.createRevocationList(
"EC",
crlServer.rootCa,
EMPTY_CRL,
true,
emptyList()
// Now simulate the CRL endpoint becoming very slow/unreachable
crlServer.delay = 10.minutes
// And pretend enough time has elapsed that the cached CRLs have expired and need downloading again
serverCrlSource.clearCache()
// Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty
// threads to be tied up in trying to download the CRLs.
IntStream.range(0, 2).parallel().forEach { clientIndex ->
val (newClient, _) = createAMQPClient(
serverPort,
crlCheckSoftFail = true,
legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"),
crlDistPoints = defaultCrlDistPoints
)
}.withMessage("Unknown signature type requested: EC")
newClient.start()
}
// Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga
assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull()
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged
)
}
protected abstract fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout),
sslHandshakeTimeout: Duration? = null,
remotingThreads: Int? = null,
clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints,
revokeClientCert: Boolean = false,
revokeServerCert: Boolean = false,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange>
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with hard fail CRL check`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unreachable URL if CRL timeout is within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout * 3
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on unreachable URL if CRL timeout is not within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedConnected = false,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
revokedNodeCert = true
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = false,
expectedStatus = MessageStatus.Acknowledged,
revokedNodeCert = true
)
}
private fun createAMQPClient(targetPort: Int,
crlCheckSoftFail: Boolean,
legalName: CordaX500Name,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL",
maxMessageSize: Int = MAX_MESSAGE_SIZE): X509Certificate {
protected fun createAMQPClient(targetPort: Int,
crlCheckSoftFail: Boolean,
legalName: CordaX500Name,
crlDistPoints: CrlDistPoints): Pair<AMQPClient, X509Certificate> {
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
@ -344,31 +288,128 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
}
clientConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint)
val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = clientConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore
override val trustStore = clientConfig.p2pSslOptions.trustStore.get()
override val maxMessageSize: Int = maxMessageSize
override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val trace: Boolean = true
}
amqpClient = AMQPClient(
val amqpClient = AMQPClient(
listOf(NetworkHostAndPort("localhost", targetPort)),
setOf(CHARLIE_NAME),
amqpConfig,
threadPoolName = legalName.organisation
threadPoolName = legalName.organisation,
distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout)
)
amqpClients += amqpClient
return Pair(amqpClient, nodeCert)
}
return nodeCert
protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val connectionChangeStatus = LinkedBlockingQueue<ConnectionChange>()
onConnection.subscribe { connectionChangeStatus.add(it) }
start()
assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus)
return connectionChangeStatus
}
protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort,
val nodeCa: String? = NODE_CRL,
val tls: String? = EMPTY_CRL) {
private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" }
private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" }
fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier,
p2pSslConfiguration: MutualSslConfiguration,
crlServer: CrlServer): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
}
}
class AMQPServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var amqpServer: AMQPServer
@After
fun shutDown() {
if (::amqpServer.isInitialized) {
amqpServer.close()
}
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(serverCert)
}
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (client, clientCert) = createAMQPClient(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
legalName = ALICE_NAME,
crlDistPoints = clientCrlDistPoints
)
if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert)
}
return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
}
private fun createAMQPServer(port: Int,
legalName: CordaX500Name,
crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL",
maxMessageSize: Int = MAX_MESSAGE_SIZE,
sslHandshakeTimeout: Duration? = null): X509Certificate {
crlDistPoints: CrlDistPoints,
distPointCrlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?): X509Certificate {
check(!::amqpServer.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates"
@ -382,92 +423,103 @@ class CertificateRevocationListNodeTests {
doReturn(signingCertificateStore).whenever(it).signingCertificateStore
}
serverConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint)
val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = serverConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore
override val trustStore = serverConfig.p2pSslOptions.trustStore.get()
override val revocationConfig = crlCheckSoftFail.toRevocationConfig()
override val maxMessageSize: Int = maxMessageSize
override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout
}
amqpServer = AMQPServer("0.0.0.0", port, amqpConfig, threadPoolName = legalName.organisation)
return nodeCert
}
private fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier,
p2pSslConfiguration: MutualSslConfiguration,
nodeCaCrlDistPoint: String?,
tlsCrlDistPoint: String?): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
private fun verifyAMQPConnection(crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
revokeServerCert: Boolean = false,
revokeClientCert: Boolean = false,
sslHandshakeTimeout: Duration? = null,
expectedConnectStatus: Boolean) {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail = crlCheckSoftFail,
nodeCrlDistPoint = nodeCrlDistPoint,
sslHandshakeTimeout = sslHandshakeTimeout
amqpServer = AMQPServer(
"0.0.0.0",
port,
amqpConfig,
threadPoolName = legalName.organisation,
distPointCrlSource = distPointCrlSource,
remotingThreads = remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(serverCert.serialNumber)
return serverCert
}
}
class ArtemisServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var artemisNode: ArtemisNode
private var crlCheckArtemisServer = true
@After
fun shutDown() {
if (::artemisNode.isInitialized) {
artemisNode.close()
}
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val clientCert = createAMQPClient(
}
@Test(timeout = 300_000)
fun `connection succeeds with disabled CRL check on revoked node certificate`() {
crlCheckArtemisServer = false
verifyConnection(
crlCheckSoftFail = false,
revokeClientCert = true,
expectedConnectedStatus = true
)
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val (client, clientCert) = createAMQPClient(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
crlCheckSoftFail = true,
legalName = ALICE_NAME,
nodeCrlDistPoint = nodeCrlDistPoint
crlDistPoints = clientCrlDistPoints
)
if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert.serialNumber)
crlServer.revokedNodeCerts.add(clientCert)
}
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertThat(serverConnect.connected).isEqualTo(expectedConnectStatus)
val nodeCert = startArtemisNode(
CHARLIE_NAME,
crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(nodeCert)
}
val queueName = "${P2P_PREFIX}Test"
artemisNode.client.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
if (expectedConnectedStatus) {
val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
client.write(msg)
assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged)
}
return clientConnectionChangeStatus
}
private fun createArtemisServerAndClient(legalName: CordaX500Name,
crlCheckSoftFail: Boolean,
crlCheckArtemisServer: Boolean,
nodeCrlDistPoint: String,
sslHandshakeTimeout: Duration?): Pair<ArtemisMessagingServer, ArtemisMessagingClient> {
val baseDirectory = temporaryFolder.root.toPath() / "artemis"
private fun startArtemisNode(legalName: CordaX500Name,
crlCheckSoftFail: Boolean,
crlDistPoints: CrlDistPoints,
distPointCrlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?): X509Certificate {
check(!::artemisNode.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout)
@ -483,62 +535,34 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer
}
artemisConfig.configureWithDevSSLCertificate()
recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, null)
val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val server = ArtemisMessagingServer(
artemisConfig,
artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-server",
trace = true
trace = true,
distPointCrlSource = distPointCrlSource,
remotingThreads = remotingThreads
)
val client = ArtemisMessagingClient(
artemisConfig.p2pSslOptions,
artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-client",
trace = true
threadPoolName = "${legalName.organisation}-client"
)
server.start()
client.start()
return server to client
val artemisNode = ArtemisNode(server, client)
this.artemisNode = artemisNode
return nodeCert
}
private fun verifyArtemisConnection(crlCheckSoftFail: Boolean,
crlCheckArtemisServer: Boolean,
expectedConnected: Boolean = true,
expectedStatus: MessageStatus? = null,
revokedNodeCert: Boolean = false,
nodeCrlDistPoint: String = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
sslHandshakeTimeout: Duration? = null) {
val queueName = P2P_PREFIX + "Test"
val (artemisServer, artemisClient) = createArtemisServerAndClient(
CHARLIE_NAME,
crlCheckSoftFail,
crlCheckArtemisServer,
nodeCrlDistPoint,
sslHandshakeTimeout
)
artemisServer.use {
artemisClient.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val nodeCert = createAMQPClient(serverPort, true, ALICE_NAME, nodeCrlDistPoint)
if (revokedNodeCert) {
crlServer.revokedNodeCerts.add(nodeCert.serialNumber)
}
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val clientConnect = clientConnected.get()
assertThat(clientConnect.connected).isEqualTo(expectedConnected)
if (expectedConnected) {
val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
amqpClient.write(msg)
assertEquals(expectedStatus, msg.onComplete.get())
}
artemisClient.stop()
private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable {
override fun close() {
client.stop()
server.close()
}
}
}

View File

@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer
@ -23,7 +26,9 @@ import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
@ -31,9 +36,6 @@ import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions.assertThatThrownBy
@ -43,7 +45,11 @@ import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import javax.net.ssl.*
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLHandshakeException
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertTrue
@ -145,15 +151,10 @@ class ProtonWrapperTests {
sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) }
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom())
@ -344,7 +345,7 @@ class ProtonWrapperTests {
amqpServer.use {
val connectionEvents = amqpServer.onConnection.toBlocking().iterator
amqpServer.start()
val sharedThreads = NioEventLoopGroup()
val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads"))
val amqpClient1 = createSharedThreadsClient(sharedThreads, 0)
val amqpClient2 = createSharedThreadsClient(sharedThreads, 1)
amqpClient1.start()

View File

@ -3,6 +3,8 @@ package net.corda.services.messaging
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.qpid.jms.JmsConnectionFactory
import org.apache.qpid.jms.meta.JmsConnectionInfo
import org.apache.qpid.jms.provider.Provider
@ -24,9 +26,7 @@ import javax.jms.Connection
import javax.jms.Message
import javax.jms.MessageProducer
import javax.jms.Session
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
/**
* Simple AMQP client connecting to broker using JMS.
@ -59,12 +59,8 @@ class SimpleAMQPClient(private val target: NetworkHostAndPort, private val confi
private lateinit var connection: Connection
private fun sslContext(): SSLContext {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply {
init(config.keyStore.get().value.internal, config.keyStore.entryPassword.toCharArray())
}
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(config.trustStore.get().value.internal)
}
val keyManagerFactory = keyManagerFactory(config.keyStore.get())
val trustManagerFactory = trustManagerFactory(config.trustStore.get())
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers

View File

@ -4,7 +4,6 @@ import co.paralleluniverse.fibers.instrument.Retransform
import com.codahale.metrics.MetricRegistry
import com.google.common.collect.MutableClassToInstanceMap
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.zaxxer.hikari.pool.HikariPool
import net.corda.common.logging.errorReporting.NodeDatabaseErrors
import net.corda.confidential.SwapIdentitiesFlow
@ -67,6 +66,7 @@ import net.corda.core.toFuture
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes
import net.corda.djvm.source.ApiSource
import net.corda.djvm.source.EmptyApi
@ -165,6 +165,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager
import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import org.apache.activemq.artemis.utils.ReusableLatch
import org.jolokia.jvmagent.JolokiaServer
import org.jolokia.jvmagent.JolokiaServerConfig
@ -180,9 +181,6 @@ import java.util.ArrayList
import java.util.Properties
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS
import java.util.function.Consumer
@ -880,13 +878,12 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
// Start with 1 thread and scale up to the configured thread pool size if needed
// Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool]
return ThreadPoolExecutor(
1,
numberOfThreads,
0L,
TimeUnit.MILLISECONDS,
LinkedBlockingQueue<Runnable>(),
ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build()
return namedThreadPoolExecutor(
corePoolSize = 1,
maxPoolSize = numberOfThreads,
idleKeepAlive = 0.millis,
poolName = "flow-external-operation-thread",
daemonThreads = true
)
}

View File

@ -2,6 +2,7 @@ package net.corda.node.internal
import net.corda.core.contracts.*
import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.SerializedStateAndRef
import net.corda.core.node.NetworkParameters
import net.corda.core.node.ServicesForResolution
@ -9,8 +10,10 @@ import net.corda.core.node.services.AttachmentStorage
import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService
import net.corda.core.node.services.TransactionStorage
import net.corda.core.transactions.BaseTransaction
import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent
@ -26,25 +29,23 @@ data class ServicesForResolutionImpl(
@Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return stx.resolveBaseTransaction(this).outputs[stateRef.index]
return toBaseTransaction(stateRef.txhash).outputs[stateRef.index]
}
@Throws(TransactionResolutionException::class)
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> {
return stateRefs.groupBy { it.txhash }.flatMap {
val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key)
val baseTx = stx.resolveBaseTransaction(this)
it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) }
}.toSet()
val baseTxs = HashMap<SecureHash, BaseTransaction>()
return stateRefs.mapTo(LinkedHashSet()) { stateRef ->
val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction)
StateAndRef(baseTx.outputs[stateRef.index], stateRef)
}
}
@Throws(TransactionResolutionException::class, AttachmentResolutionException::class)
override fun loadContractAttachment(stateRef: StateRef): Attachment {
// We may need to recursively chase transactions if there are notary changes.
fun inner(stateRef: StateRef, forContractClassName: String?): Attachment {
val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction
?: throw TransactionResolutionException(stateRef.txhash)
val ctx = getSignedTransaction(stateRef.txhash).coreTransaction
when (ctx) {
is WireTransaction -> {
val transactionState = ctx.outRef<ContractState>(stateRef.index).state
@ -69,4 +70,10 @@ data class ServicesForResolutionImpl(
}
return inner(stateRef, null)
}
private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this)
private fun getSignedTransaction(txhash: SecureHash): SignedTransaction {
return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash)
}
}

View File

@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() {
Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE)))
}
ArtemisMessagingComponent.PEER_USER -> {
requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." }
val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." }
requireTls(certificates)
// This check is redundant as it was performed already during the SSL handshake
CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates)
CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode)
.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates)
CertificateChainCheckPolicy.RootMustMatch
.createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore)
.checkCertificateChain(certificates)
Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE)))
}
else -> {

View File

@ -2,17 +2,9 @@ package net.corda.node.internal.artemis
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString
import java.security.KeyStore
import java.security.cert.CertPathValidator
import java.security.cert.CertPathValidatorException
import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.X509CertSelector
sealed class CertificateChainCheckPolicy {
companion object {
@ -92,33 +84,4 @@ sealed class CertificateChainCheckPolicy {
}
}
}
class RevocationCheck(val revocationConfig: RevocationConfig) : CertificateChainCheckPolicy() {
constructor(revocationMode: RevocationConfig.Mode) : this(RevocationConfigImpl(revocationMode))
override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check {
return object : Check {
override fun checkCertificateChain(theirChain: Array<java.security.cert.X509Certificate>) {
// Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate.
val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) }
log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}")
// Drop the last certificate which must be a trusted root (validated by RootMustMatch).
// Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain.
// See PKIXValidator.engineValidate() for reference implementation.
val certPath = X509Utilities.buildCertPath(chain.dropLast(1))
val certPathValidator = CertPathValidator.getInstance("PKIX")
val pkixRevocationChecker = revocationConfig.createPKIXRevocationChecker()
val params = PKIXBuilderParameters(trustStore, X509CertSelector())
params.addCertPathChecker(pkixRevocationChecker)
try {
certPathValidator.validate(certPath, params)
} catch (ex: CertPathValidatorException) {
log.error("Bad certificate path", ex)
throw ex
}
}
}
}
}
}

View File

@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.internal.artemis.*
import net.corda.node.internal.artemis.ArtemisBroker
import net.corda.node.internal.artemis.BrokerAddresses
import net.corda.node.internal.artemis.BrokerJaasLoginModule
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE
import net.corda.node.internal.artemis.NodeJaasConfig
import net.corda.node.internal.artemis.P2PJaasConfig
import net.corda.node.internal.artemis.SecureArtemisConfiguration
import net.corda.node.internal.artemis.UserValidationPlugin
import net.corda.node.internal.artemis.isBindingError
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.utilities.artemis.startSynchronously
import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor
@ -21,7 +28,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.nodeapi.internal.requireOnDefaultFileSystem
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
@ -32,9 +42,7 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import java.io.IOException
import java.lang.Long.max
import java.security.KeyStoreException
import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.login.AppConfigurationEntry
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED
@ -58,7 +66,9 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
private val maxMessageSize: Int,
private val journalBufferTimeout : Int? = null,
private val threadPoolName: String = "ArtemisServer",
private val trace: Boolean = false) : ArtemisBroker, SingletonSerializeAsToken() {
private val trace: Boolean = false,
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() {
companion object {
private val log = contextLogger()
}
@ -92,7 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
override val started: Boolean
get() = activeMQServer.isStarted
@Throws(IOException::class, AddressBindingException::class, KeyStoreException::class)
@Suppress("ThrowsCount")
private fun configureAndStartServer() {
val artemisConfig = createArtemisConfig()
val securityManager = createArtemisSecurityManager()
@ -132,11 +142,23 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
// The transaction cache is configurable, and drives other cache sizes.
globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize)
val revocationMode = if (config.crlCheckArtemisServer) {
if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL
} else {
RevocationConfig.Mode.OFF
}
val trustManagerFactory = trustManagerFactoryWithRevocation(
config.p2pSslOptions.trustStore.get(),
RevocationConfigImpl(revocationMode),
distPointCrlSource
)
addAcceptorConfiguration(p2pAcceptorTcpTransport(
NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port),
config.p2pSslOptions,
trustManagerFactory,
threadPoolName = threadPoolName,
trace = trace
trace = trace,
remotingThreads = remotingThreads
))
// Enable built in message deduplication. Note we still have to do our own as the delayed commits
// and our own definition of commit mean that the built in deduplication cannot remove all duplicates.

View File

@ -10,6 +10,7 @@ import io.netty.handler.ssl.SslHandshakeTimeoutException
import net.corda.core.internal.declaredField
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.core.BaseInterceptor
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor
import org.apache.activemq.artemis.core.server.balancing.RedirectHandler
@ -19,14 +20,18 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor
import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory
import org.apache.activemq.artemis.spi.core.remoting.BufferHandler
import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener
import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider
import org.apache.activemq.artemis.utils.ConfigurationHelper
import org.apache.activemq.artemis.utils.actors.OrderedExecutor
import java.net.SocketAddress
import java.nio.channels.ClosedChannelException
import java.time.Duration
import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService
import java.util.regex.Pattern
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException
@Suppress("unused") // Used via reflection in ArtemisTcpTransport
class NodeNettyAcceptorFactory : AcceptorFactory {
@ -55,9 +60,16 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
{
companion object {
private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""")
init {
// Make sure Artemis isn't using another (Open)SSLContextFactory
check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory)
check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory)
}
}
private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration)
private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
@Synchronized
@ -71,11 +83,17 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
}
}
@Synchronized
override fun stop() {
super.stop()
sslDelegatedTaskExecutor.shutdown()
}
@Synchronized
override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler {
applyThreadPoolName()
val engine = super.getSslHandler(alloc, peerHost, peerPort).engine()
val sslHandler = NodeAcceptorSslHandler(engine, trace)
val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace)
val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration?
if (handshakeTimeout != null) {
sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis()
@ -95,13 +113,15 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
}
private class NodeAcceptorSslHandler(engine: SSLEngine, private val trace: Boolean) : SslHandler(engine) {
private class NodeAcceptorSslHandler(engine: SSLEngine,
delegatedTaskExecutor: Executor,
private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) {
companion object {
private val logger = contextLogger()
}
override fun handlerAdded(ctx: ChannelHandlerContext) {
logHandshake()
logHandshake(ctx.channel().remoteAddress())
super.handlerAdded(ctx)
// Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way.
if (trace) {
@ -109,17 +129,22 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
}
}
private fun logHandshake() {
private fun logHandshake(remoteAddress: SocketAddress) {
val start = System.currentTimeMillis()
handshakeFuture().addListener {
val duration = System.currentTimeMillis() - start
val peer = try {
engine().session.peerPrincipal
} catch (e: SSLPeerUnverifiedException) {
remoteAddress
}
when {
it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with ${engine().session.peerPrincipal}")
it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms")
it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with $peer")
it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms with $peer")
else -> when (it.cause()) {
is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms")
is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms")
else -> logger.warn("SSL handshake failed after ${duration}ms", it.cause())
is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms with $peer")
is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms with $peer")
else -> logger.warn("SSL handshake failed after ${duration}ms with $peer", it.cause())
}
}
}

View File

@ -0,0 +1,59 @@
package net.corda.node.services.messaging
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig
import java.nio.file.Paths
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
class NodeSSLContextFactory : DefaultSSLContextFactory() {
override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SSLContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory)
} else {
super.getSSLContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() {
override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SslContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
SslContextBuilder
.forServer(loadKeyManagerFactory(config))
.sslProvider(SslProvider.OPENSSL)
.trustManager(trustManagerFactory)
.build()
} else {
super.getServerSslContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory {
val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false)
return keyManagerFactory(keyStore)
}

View File

@ -30,7 +30,7 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int,
setDirectories(baseDirectory)
val acceptorConfigurationsSet = mutableSetOf(
rpcAcceptorTcpTransport(address, sslOptions, useSsl)
rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl)
)
adminAddress?.let {
acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration)

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeOpenSSLContextFactory

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeSSLContextFactory

View File

@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
bobNode.internals.disableDBCloseOnStop()
bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer)
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10)
}
val alicesFakePaper = aliceNode.database.transaction {
@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
val issuer = bank.ref(1, 2, 3)
bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer)
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10)
}
val alicesFakePaper = aliceNode.database.transaction {
fillUpForSeller(false, issuer, alice,

View File

@ -244,7 +244,6 @@ class FlowSoftLocksTests {
100.DOLLARS,
bankNode.services,
thisManyStates,
thisManyStates,
cashIssuer
)
}

View File

@ -20,14 +20,13 @@ import net.corda.finance.*
import net.corda.finance.contracts.CommercialPaper
import net.corda.finance.contracts.Commodity
import net.corda.finance.contracts.DealState
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.finance.contracts.asset.Cash
import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.schemas.CashSchemaV1.PersistentCashState
import net.corda.finance.schemas.CommercialPaperSchemaV1
import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3
import net.corda.finance.workflows.CommercialPaperUtils
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
private fun setUpDb(_database: CordaPersistence, delay: Long = 0) {
_database.transaction {
private fun setUpDb(database: CordaPersistence, delay: Long = 0) {
database.transaction {
// create new states
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER)
val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ")
@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates)
}
(1..3).forEach {
repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates)
}
(1..3).forEach {
repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
val sorted = results.states.sortedBy { it.ref.toString() }
assertThat(results.states).isEqualTo(sorted)
assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) }
assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) }
}
}
@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789"))
// count fungible assets
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count)
val countCriteria = VaultCustomQueryCriteria(count)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L)
@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
// count fungible assets
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L)
@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// UNCONSUMED states (default)
// count fungible assets
val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val fungibleStateCountUnconsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size)
@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// CONSUMED states
// count fungible assets
val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val fungibleStateCountConsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size)
@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val start = TODAY
val end = TODAY.plus(30, ChronoUnit.DAYS)
val recordedBetweenExpression = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED,
TimeInstantType.RECORDED,
ColumnPredicate.Between(start, end))
val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression)
val results = vaultService.queryBy<ContractState>(criteria)
@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Future
val startFuture = TODAY.plus(1, ChronoUnit.DAYS)
val recordedBetweenExpressionFuture = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end))
TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end))
val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture)
assertThat(vaultService.queryBy<ContractState>(criteriaFuture).states).isEmpty()
}
@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
consumeCash(100.DOLLARS)
val asOfDateTime = TODAY
val consumedAfterExpression = TimeCondition(
QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime))
TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime))
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED,
timeCondition = consumedAfterExpression)
val results = vaultService.queryBy<ContractState>(criteria)
@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
// pagination: invalid page size
@Suppress("INTEGER_OVERFLOW")
@Test(timeout=300_000)
fun `invalid page size`() {
expectedEx.expect(VaultQueryException::class.java)
@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
database.transaction {
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER)
@Suppress("EXPECTED_CONDITION")
val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648
val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.queryBy<ContractState>(criteria, paging = pagingSpec)
}
@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
println("$index : $any")
}
assertThat(results.otherResults.size).isEqualTo(402)
val instants = results.otherResults.filter { it is Instant }.map { it as Instant }
val instants = results.otherResults.filterIsInstance<Instant>()
assertThat(instants).isSorted
val longs = results.otherResults.filter { it is Long }.map { it as Long }
val longs = results.otherResults.filterIsInstance<Long>()
assertThat(longs.size).isEqualTo(201)
assertThat(instants.size).isEqualTo(201)
assertThat(longs.sum()).isEqualTo(20100L)
@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() {
database.transaction {
val uid = UniqueIdentifier("999")
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid)
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234")
vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid)
vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234")
val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id))
val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234"))
@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
}
@Test(timeout = 300_000)
fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() {
val linearNumbers = Array(2) { LongArray(2) }
// Make sure states from the same transaction are not given consecutive linear numbers.
linearNumbers[0][0] = 1L
linearNumbers[0][1] = 3L
linearNumbers[1][0] = 2L
linearNumbers[1][1] = 4L
val results = database.transaction {
vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex ->
DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex])
}
val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber"))
vaultService.queryBy<DummyLinearContract.State>(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn)))
}
assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L))
}
@Test(timeout=300_000)
fun `return consumed linear states for a given linear id`() {
database.transaction {
@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
services.recordTransactions(commercialPaper2)
val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) }
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex)
val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val result = vaultService.queryBy<CommercialPaper.State>(criteria1)
@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days)
val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L)
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex)
val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex)
val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex)
val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val criteria2 = VaultCustomQueryCriteria(maturityIndex)
val criteria3 = VaultCustomQueryCriteria(faceValueIndex)
vaultService.queryBy<CommercialPaper.State>(criteria1.and(criteria3).and(criteria2))
}
@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL)
val results = builder {
val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode)
val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L)
val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode)
val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L)
val customCriteria1 = VaultCustomQueryCriteria(currencyIndex)
val customCriteria2 = VaultCustomQueryCriteria(quantityIndex)
@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Enrich and override QueryCriteria with additional default attributes (such as soft locks)
val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich
softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())),
softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())),
status = Vault.StateStatus.UNCONSUMED) // override
// Sorting
val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF)
@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush()
vaultFiller.consumeLinearStates(states.toList())
vaultFiller.consumeStates(states)
updates
}
@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush()
vaultFiller.consumeLinearStates(states.toList())
vaultFiller.consumeStates(states)
updates
}
@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush()
vaultFiller.consumeLinearStates(states.toList())
vaultFiller.consumeStates(states)
updates
}

View File

@ -1,30 +1,50 @@
@file:Suppress("UNUSED_PARAMETER")
@file:JvmName("TestUtils")
@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList")
package net.corda.testing.core
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.*
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureScheme
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.toX500Name
import net.corda.core.internal.unspecifiedCountry
import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import java.math.BigInteger
import java.net.URI
import java.security.KeyPair
import java.security.PublicKey
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.fail
@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party
return getTestPartyAndCertificate(Party(name, publicKey))
}
fun createCRL(issuer: CertificateAndKeyPair,
revokedCerts: List<X509Certificate>,
issuingDistPoint: URI? = null,
thisUpdate: Instant = Instant.now(),
nextUpdate: Instant = thisUpdate + 5.minutes,
indirect: Boolean = false,
revocationDate: Instant = thisUpdate,
crlReason: Int = CRLReason.keyCompromise,
signatureAlgorithm: String = "SHA256withECDSA"): X509CRL {
val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate))
// This is required and needs to match the certificate settings with respect to being indirect
builder.addExtension(
Extension.issuingDistributionPoint,
true,
IssuingDistributionPoint(
issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) },
indirect,
false
)
)
builder.setNextUpdate(Date.from(nextUpdate))
for (revokedCert in revokedCerts) {
val extensionsGenerator = ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason))
// Certificate issuer is required for indirect CRL
extensionsGenerator.addExtension(
Extension.certificateIssuer,
true,
GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name()))
)
builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate())
}
val bcProvider = Crypto.findProvider("BC")
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private)
return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer))
}
private val count = AtomicInteger(0)
/**
@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party
* The above will test our expectation that the getWaitingFlows action was executed successfully considering
* that it may take a few hundreds of milliseconds for the flow state machine states to settle.
*/
@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod")
fun <T> executeTest(
timeout: Duration,
cleanup: (() -> Unit)? = null,

View File

@ -4,30 +4,26 @@ package net.corda.testing.node.internal.network
import net.corda.core.crypto.Crypto
import net.corda.core.internal.CertRole
import net.corda.core.internal.toX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.days
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.ContentSignerBuilder
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import net.corda.nodeapi.internal.crypto.certificateType
import net.corda.nodeapi.internal.crypto.toJca
import org.bouncycastle.asn1.x500.X500Name
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.asn1.x509.ReasonFlags
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.ServerConnector
import org.eclipse.jetty.server.handler.HandlerCollection
@ -36,11 +32,12 @@ import org.eclipse.jetty.servlet.ServletHolder
import org.glassfish.jersey.server.ResourceConfig
import org.glassfish.jersey.servlet.ServletContainer
import java.io.Closeable
import java.math.BigInteger
import java.net.InetSocketAddress
import java.net.URI
import java.security.KeyPair
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.*
import javax.security.auth.x500.X500Principal
import javax.ws.rs.GET
@ -51,7 +48,7 @@ import kotlin.collections.ArrayList
class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
companion object {
private const val SIGNATURE_ALGORITHM = "SHA256withECDSA"
private val logger = contextLogger()
const val NODE_CRL = "node.crl"
const val FORBIDDEN_CRL = "forbidden.crl"
@ -72,8 +69,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
null
)
if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint)))
val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(X500Name.getInstance(it.encoded))) }
val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) }
val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
}
@ -87,14 +84,17 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
}
}
val revokedNodeCerts: MutableList<BigInteger> = ArrayList()
val revokedIntermediateCerts: MutableList<BigInteger> = ArrayList()
val revokedNodeCerts: MutableList<X509Certificate> = ArrayList()
val revokedIntermediateCerts: MutableList<X509Certificate> = ArrayList()
val rootCa: CertificateAndKeyPair = DEV_ROOT_CA
private lateinit var _intermediateCa: CertificateAndKeyPair
val intermediateCa: CertificateAndKeyPair get() = _intermediateCa
@Volatile
var delay: Duration? = null
val hostAndPort: NetworkHostAndPort
get() = server.connectors.mapNotNull { it as? ServerConnector }
.map { NetworkHostAndPort(it.host, it.localPort) }
@ -106,7 +106,7 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"),
DEV_INTERMEDIATE_CA.keyPair
)
println("Network management web services started on $hostAndPort")
logger.info("Network management web services started on $hostAndPort")
}
fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate,
@ -115,29 +115,20 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer)
}
fun createRevocationList(signatureAlgorithm: String,
ca: CertificateAndKeyPair,
endpoint: String,
indirect: Boolean,
serialNumbers: List<BigInteger>): X509CRL {
println("Generating CRL for $endpoint")
val builder = JcaX509v2CRLBuilder(ca.certificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis()))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(ca.certificate))
val issuingDistPointName = GeneralName(GeneralName.uniformResourceIdentifier, "http://$hostAndPort/crl/$endpoint")
// This is required and needs to match the certificate settings with respect to being indirect
val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
serialNumbers.forEach {
builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
}
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(ca.keyPair.private)
return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer))
private fun createServerCRL(issuer: CertificateAndKeyPair,
endpoint: String,
indirect: Boolean,
revokedCerts: List<X509Certificate>): X509CRL {
logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}")
return createCRL(
issuer,
revokedCerts,
issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"),
indirect = indirect
)
}
override fun close() {
println("Shutting down network management web services...")
server.stop()
server.join()
}
@ -159,8 +150,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(NODE_CRL)
@Produces("application/pkcs7-crl")
fun getNodeCRL(): Response {
return Response.ok(crlServer.createRevocationList(
SIGNATURE_ALGORITHM,
crlServer.delay?.toMillis()?.let(Thread::sleep)
return Response.ok(crlServer.createServerCRL(
crlServer.intermediateCa,
NODE_CRL,
false,
@ -179,8 +170,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(INTERMEDIATE_CRL)
@Produces("application/pkcs7-crl")
fun getIntermediateCRL(): Response {
return Response.ok(crlServer.createRevocationList(
SIGNATURE_ALGORITHM,
crlServer.delay?.toMillis()?.let(Thread::sleep)
return Response.ok(crlServer.createServerCRL(
crlServer.rootCa,
INTERMEDIATE_CRL,
false,
@ -192,11 +183,11 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(EMPTY_CRL)
@Produces("application/pkcs7-crl")
fun getEmptyCRL(): Response {
return Response.ok(crlServer.createRevocationList(
SIGNATURE_ALGORITHM,
return Response.ok(crlServer.createServerCRL(
crlServer.rootCa,
EMPTY_CRL,
true, emptyList()
true,
emptyList()
).encoded).build()
}
}

View File

@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.serialization.internal.amqp.AMQP_ENABLED
import net.corda.testing.core.ALICE_NAME
@ -52,6 +53,8 @@ import java.io.IOException
import java.net.ServerSocket
import java.nio.file.Path
import java.security.KeyPair
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.*
import java.util.jar.JarOutputStream
import java.util.jar.Manifest
@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L
return sslConfig
}
fun fixedCrlSource(crls: Set<X509CRL>): CrlSource {
return object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = crls
}
}
/** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */
@SuppressWarnings("LongParameterList")
fun createWireTransaction(inputs: List<StateRef>,

View File

@ -1,6 +1,20 @@
@file:Suppress("LongParameterList")
package net.corda.testing.internal.vault
import net.corda.core.contracts.*
import net.corda.core.contracts.Amount
import net.corda.core.contracts.AttachmentConstraint
import net.corda.core.contracts.AutomaticPlaceholderConstraint
import net.corda.core.contracts.BelongsToContract
import net.corda.core.contracts.CommandAndState
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.Issued
import net.corda.core.contracts.LinearState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.identity.AbstractParty
@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash
import net.corda.finance.contracts.asset.Obligation
import net.corda.finance.contracts.asset.OnLedgerAsset
import net.corda.finance.workflows.asset.CashUtils
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
import net.corda.testing.core.DummyCommandData
import net.corda.testing.core.TestIdentity
import net.corda.testing.core.dummyCommand
import net.corda.testing.core.singleIdentity
@ -32,6 +44,7 @@ import java.time.Duration
import java.time.Instant
import java.time.Instant.now
import java.util.*
import kotlin.math.floor
/**
* The service hub should provide at least a key management service and a storage service.
@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor(
private val rngFactory: () -> Random = { Random(0L) }) {
companion object {
fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray {
val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt()
val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt()
val baseSize = howMuch.quantity / numSlots
check(baseSize > 0) { baseSize }
@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor(
issuerServices: ServiceHub = services,
participants: List<AbstractParty> = emptyList(),
includeMe: Boolean = true): Vault<DealState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey
val me = AnonymousParty(myKey)
val participantsToUse = if (includeMe) participants.plus(me) else participants
val transactions: List<SignedTransaction> = dealIds.map {
// Issue a deal state
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID)
addCommand(dummyCommand())
}
val stx = issuerServices.signInitialTransaction(dummyIssue)
return@map services.addSignature(stx, defaultNotary.publicKey)
return fillWithTestStates(
txCount = dealIds.size,
participants = participants,
includeMe = includeMe,
services = issuerServices
) { participantsToUse, txIndex, _ ->
DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<DealState>(i) }
}
return Vault(states)
}
@JvmOverloads
fun fillWithSomeTestLinearStates(numberToCreate: Int,
fun fillWithSomeTestLinearStates(txCount: Int,
externalId: String? = null,
participants: List<AbstractParty> = emptyList(),
uniqueIdentifier: UniqueIdentifier? = null,
@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor(
linearTimestamp: Instant = now(),
constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
includeMe: Boolean = true): Vault<LinearState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey
val me = AnonymousParty(myKey)
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID)
val participantsToUse = if (includeMe) participants.plus(me) else participants
val transactions: List<SignedTransaction> = (1..numberToCreate).map {
// Issue a Linear state
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
addOutputState(DummyLinearContract.State(
linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID,
constraint = constraint)
addCommand(dummyCommand())
}
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ ->
DummyLinearContract.State(
linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp
)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
}
@JvmOverloads
fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int,
fun fillWithSomeTestLinearAndDealStates(txCount: Int,
externalId: String? = null,
participants: List<AbstractParty> = emptyList(),
linearString: String = "",
linearNumber: Long = 0L,
linearBoolean: Boolean = false,
linearTimestamp: Instant = now()): Vault<LinearState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey
val me = AnonymousParty(myKey)
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID)
val transactions: List<SignedTransaction> = (1..numberToCreate).map {
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
// Issue a Linear state
addOutputState(DummyLinearContract.State(
linearTimestamp: Instant = now()): Vault<ContractState> {
return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex ->
when (stateIndex) {
0 -> DummyLinearContract.State(
linearId = UniqueIdentifier(externalId),
participants = participants.plus(me),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID)
// Issue a Deal state
addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID)
addCommand(dummyCommand())
linearTimestamp = linearTimestamp
)
else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse)
}
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
}
services.recordTransactions(transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
}
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub,
thisManyStates: Int,
issuedBy: PartyAndReference,
owner: AbstractParty? = null,
rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord)
/**
* Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them
* to the vault. This is intended for unit tests. By default the cash is owned by the legal
@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor(
* @param issuerServices service hub of the issuer node, which will be used to sign the transaction.
* @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!).
*/
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub,
atLeastThisManyStates: Int,
atMostThisManyStates: Int,
issuedBy: PartyAndReference,
owner: AbstractParty? = null,
rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<Cash.State> {
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT,
atMostThisManyStates: Int = atLeastThisManyStates): Vault<Cash.State> {
val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory())
// We will allocate one state to one transaction, for simplicities sake.
val cash = Cash()
@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor(
cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary)
return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
}
services.recordTransactions(statesToRecord, transactions)
// Get all the StateRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<Cash.State>(i) }
}
return Vault(states)
return recordTransactions(transactions, statesToRecord)
}
/**
* Records a dummy state in the Vault (useful for creating random states when testing vault queries)
*/
fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())) : Vault<DummyState> {
val outputState = TransactionState(
data = DummyState(Random().nextInt(), participants = participants),
contract = DummyContract.PROGRAM_ID,
notary = defaultNotary.party
)
val participantKeys : List<PublicKey> = participants.map { it.owningKey }
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, participantKeys)
val stxn = services.signInitialTransaction(builder)
services.recordTransactions(stxn)
return Vault(setOf(stxn.tx.outRef(0)))
fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())): Vault<DummyState> {
return fillWithTestStates(participants = participants) { participantsToUse, _, _ ->
DummyState(Random().nextInt(), participants = participantsToUse)
}
}
/**
* Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey.
*/
fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount<Issued<Commodity>>, owner: AbstractParty, notary: Party)
= OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue())
fun <T : ContractState> fillWithTestStates(txCount: Int = 1,
statesPerTx: Int = 1,
participants: List<AbstractParty> = emptyList(),
constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
includeMe: Boolean = true,
services: ServiceHub = this.services,
genOutputState: (participantsToUse: List<AbstractParty>, txIndex: Int, stateIndex: Int) -> T): Vault<T> {
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(
services.myInfo.platformVersion,
Crypto.findSignatureScheme(issuerKey.public).schemeNumberID
)
val participantsToUse = if (includeMe) {
participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey)
} else {
participants
}
val transactions = Array(txCount) { txIndex ->
val builder = TransactionBuilder(notary = defaultNotary.party)
repeat(statesPerTx) { stateIndex ->
builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint)
}
builder.addCommand(dummyCommand())
services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
return recordTransactions(transactions.asList(), statesToRecord)
}
/**
*
@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor(
val me = AnonymousParty(myKey)
val issuance = TransactionBuilder(null as Party?)
generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary)
OnLedgerAsset.generateIssue(
issuance,
TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary),
Obligation.Commands.Issue()
)
val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
services.recordTransactions(transaction)
return Vault(setOf(transaction.tx.outRef(0)))
return recordTransactions(listOf(transaction))
}
private fun <T : LinearState> consume(states: List<StateAndRef<T>>) {
fun consumeStates(states: Iterable<StateAndRef<*>>) {
// Create a txn consuming different contract types
states.forEach {
val builder = TransactionBuilder(notary = altNotary).apply {
@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor(
}
}
fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consume(dealStates)
fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consume(linearStates)
fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consumeStates(dealStates)
fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeStates(linearStates)
fun evolveLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeAndProduce(linearStates)
fun evolveLinearState(linearState: StateAndRef<LinearState>): StateAndRef<LinearState> = consumeAndProduce(linearState)
/**
* Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios,
* where nodes have a default identity.
@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor(
services.recordTransactions(spendTx)
return update.getOrThrow(Duration.ofSeconds(3))
}
private fun <T : ContractState> recordTransactions(transactions: Iterable<SignedTransaction>,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<T> {
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<T>(i) }
}
return Vault(states)
}
}
@ -344,4 +326,3 @@ data class CommodityState(
override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner))
}