Merge branch 'release/os/4.9' into shams-4.10-merge-e6a80822

# Conflicts:
#	.github/workflows/check-pr-title.yml
#	.snyk
#	node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt
#	node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt
#	node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
This commit is contained in:
Shams Asari 2023-07-13 10:53:30 +01:00
commit 3a6deeefa7
67 changed files with 1687 additions and 1204 deletions
.github/workflows
.snyk
client/rpc/src/main/kotlin/net/corda/client/rpc
core-tests/src/test/kotlin/net/corda/coretests/crypto/internal
core/src/main/kotlin/net/corda/core/node
node-api-tests/src/test/kotlin/net/corda/nodeapitests/internal/crypto
node-api/src
node
testing
core-test-utils/src/main/kotlin/net/corda/testing/core
node-driver/src/main/kotlin/net/corda/testing/node
test-utils/src/main/kotlin/net/corda/testing/internal

View File

@ -9,6 +9,6 @@ jobs:
steps:
- uses: morrisoncole/pr-lint-action@v1.6.1
with:
title-regex: '^((CORDA|AG|EG|ENT|INFRA|NAAS|ES)-\d+|NOTICK)(.*)'
title-regex: '^((CORDA|AG|EG|ENT|INFRA|ES)-\d+|NOTICK)(.*)'
on-failed-regex-comment: "PR title failed to match regex -> `%regex%`"
repo-token: "${{ secrets.GITHUB_TOKEN }}"

4
.snyk
View File

@ -159,7 +159,7 @@ ignore:
assessment. Liquibase is used to apply the database migration changes.
XML files are used here to define the changes not YAML and therefore
the Corda node itself is not exposed to this deserialisation
vulnerability.
vulnerability.
expires: 2023-07-12T17:00:51.957Z
created: 2022-12-29T17:00:51.970Z
SNYK-JAVA-ORGYAML-3016889:
@ -180,7 +180,7 @@ ignore:
- '*':
reason: >-
H2 console is not enabled for any of the applications we are running.
When it comes to DB connectivity parameters, we do not allow changing
When it comes to DB connectivity parameters, we do not allow changing
them as they are supplied by Corda Node configuration file.
expires: 2023-07-28T11:36:39.068Z
created: 2022-12-29T11:36:39.089Z

View File

@ -1,5 +1,6 @@
package net.corda.client.rpc
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.ReconnectingCordaRPCOps
import net.corda.client.rpc.internal.SerializationEnvironmentHelper
@ -52,7 +53,7 @@ class CordaRPCConnection private constructor(
sslConfiguration: ClientRpcSslOptions? = null,
classLoader: ClassLoader? = null
): CordaRPCConnection {
val observersPool: ExecutorService = Executors.newCachedThreadPool()
val observersPool: ExecutorService = Executors.newCachedThreadPool(DefaultThreadFactory("RPCObserver"))
return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps(
addresses,
username,

View File

@ -1,5 +1,6 @@
package net.corda.client.rpc.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.client.rpc.ConnectionFailureException
import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCClientConfiguration
@ -99,7 +100,8 @@ class ReconnectingCordaRPCOps private constructor(
ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps
}
}
private val retryFlowsPool = Executors.newScheduledThreadPool(1)
private val retryFlowsPool = Executors.newScheduledThreadPool(1, DefaultThreadFactory("FlowRetry"))
/**
* This function runs a flow and retries until it completes successfully.
*

View File

@ -0,0 +1,29 @@
package net.corda.coretests.crypto.internal
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.testing.core.createCRL
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.Test
class ProviderMapTest {
// https://github.com/corda/corda/pull/3997
@Test(timeout = 300_000)
fun `verify CRL algorithms`() {
val crl = createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "SHA256withECDSA"
)
// This should pass.
crl.verify(DEV_ROOT_CA.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "EC"
)
}.withMessage("Unknown signature type requested: EC")
}
}

View File

@ -65,7 +65,7 @@ interface ServicesForResolution {
/**
* Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState].
*
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction.
* @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction.
*/
// TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load
// as the existing transaction store will become encrypted at some point

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.core.node.services
import co.paralleluniverse.fibers.Suspendable
@ -197,8 +199,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
* 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL].
* 5) Other results as a [List] of any type (eg. aggregate function results with/without group by).
*
* Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata
* results will be empty).
* Note: currently [otherResults] is used only for aggregate functions (in which case, [states] and [statesMetadata] will be empty).
*/
@CordaSerializable
data class Page<out T : ContractState>(val states: List<StateAndRef<T>>,
@ -213,11 +214,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
val contractStateClassName: String,
val recordedTime: Instant,
val consumedTime: Instant?,
val status: Vault.StateStatus,
val status: StateStatus,
val notary: AbstractParty?,
val lockId: String?,
val lockUpdateTime: Instant?,
val relevancyStatus: Vault.RelevancyStatus? = null,
val relevancyStatus: RelevancyStatus? = null,
val constraintInfo: ConstraintInfo? = null
) {
fun copy(
@ -225,7 +226,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status,
status: StateStatus = this.status,
notary: AbstractParty? = this.notary,
lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime
@ -237,11 +238,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status,
status: StateStatus = this.status,
notary: AbstractParty? = this.notary,
lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime,
relevancyStatus: Vault.RelevancyStatus?
relevancyStatus: RelevancyStatus?
): StateMetadata {
return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint))
}
@ -249,9 +250,9 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
companion object {
@Deprecated("No longer used. The vault does not emit empty updates")
val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL, references = emptySet())
val NoUpdate = Update(emptySet(), emptySet(), type = UpdateType.GENERAL, references = emptySet())
@Deprecated("No longer used. The vault does not emit empty updates")
val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE, references = emptySet())
val NoNotaryUpdate = Update(emptySet(), emptySet(), type = UpdateType.NOTARY_CHANGE, references = emptySet())
}
}
@ -302,7 +303,7 @@ interface VaultService {
fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> {
val query = QueryCriteria.VaultQueryCriteria(
stateRefs = listOf(ref),
status = Vault.StateStatus.CONSUMED
status = StateStatus.CONSUMED
)
val result = trackBy<ContractState>(query)
val snapshot = result.snapshot.states
@ -358,8 +359,8 @@ interface VaultService {
/**
* Helper function to determine spendable states and soft locking them.
* Currently performance will be worse than for the hand optimised version in
* [Cash.unconsumedCashStatesForSpending]. However, this is fully generic and can operate with custom [FungibleState]
* and [FungibleAsset] states.
* [net.corda.finance.workflows.asset.selection.AbstractCashSelection.unconsumedCashStatesForSpending]. However, this is fully generic
* and can operate with custom [FungibleState] and [FungibleAsset] states.
* @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states.
* @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the
* [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the

View File

@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.days
import net.corda.core.utilities.hours
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.nodeapi.internal.installDevNodeCaCertPath
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.SerializationFactoryImpl
@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.testing.internal.IS_OPENJ9
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.junit.Assume
@ -74,10 +80,19 @@ import java.security.PrivateKey
import java.security.cert.CertPath
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.*
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import javax.security.auth.x500.X500Principal
import kotlin.concurrent.thread
import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
class X509UtilitiesTest {
private companion object {
@ -295,15 +310,10 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom())
@ -388,15 +398,8 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get())
val sslServerContext = SslContextBuilder
.forServer(keyManagerFactory)

View File

@ -42,8 +42,8 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
override fun start(): Started = synchronized(this) {
check(started == null) { "start can't be called twice" }
val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace)
val backupTransports = backupServerAddressPool.map {
p2pConnectorTcpTransport(it, config, threadPoolName = threadPoolName, trace = trace)
val backupTransports = backupServerAddressPool.mapIndexed { index, address ->
p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace)
}
log.info("Connecting to message broker: $serverAddress")

View File

@ -1,16 +1,18 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path
import javax.net.ssl.TrustManagerFactory
@Suppress("LongParameterList")
class ArtemisTcpTransport {
@ -23,6 +25,7 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2")
const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
const val TRACE_NAME = "Corda-Trace"
const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
@ -30,7 +33,6 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats.
private const val P2P_PROTOCOLS = "CORE,AMQP"
private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
@ -39,46 +41,35 @@ class ArtemisTcpTransport {
TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false)
private val defaultSSLOptions = mapOf(
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(","))
private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
if (keyStore != null || trustStore != null) {
options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
}
keyStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path))
options[TransportConstants.KEYSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
}
}
trustStore?.let {
with (it) {
path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path))
options[TransportConstants.TRUSTSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
}
}
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
}
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider,
@ -94,76 +85,110 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory),
enableSSL: Boolean = true,
threadPoolName: String = "P2PServer",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createAcceptorTransport(
hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
}
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?,
enableSSL: Boolean = true,
threadPoolName: String = "P2PClient",
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (enableSSL) {
config?.addToTransportOptions(options)
}
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace)
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
}
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: BrokerRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
threadPoolName: String = "RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads)
}
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: ClientRpcSslOptions?,
enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration {
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions())
}
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
}
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration,
threadPoolName: String = "Internal-RPCClient",
trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null)
}
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration,
trace: Boolean = false): TransportConfiguration {
threadPoolName: String = "Internal-RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace)
return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
threadPoolName,
trace,
remotingThreads
)
}
private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
protocols: String,
options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort,
@ -171,7 +196,8 @@ class ArtemisTcpTransport {
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -180,15 +206,21 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
if (enableSSL) {
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
}
return createTransport(
NodeNettyConnectorFactory::class.java.name,
CordaNettyConnectorFactory::class.java.name,
hostAndPort,
protocols,
options,
enableSSL,
threadPoolName,
trace
trace,
remotingThreads
)
}
@ -198,13 +230,15 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>,
enableSSL: Boolean,
threadPoolName: String,
trace: Boolean): TransportConfiguration {
trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) {
options += defaultSSLOptions
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
}
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace
return TransportConfiguration(className, options)

View File

@ -1,8 +1,14 @@
@file:JvmName("ArtemisUtils")
package net.corda.nodeapi.internal
import net.corda.core.internal.declaredField
import org.apache.activemq.artemis.utils.actors.ProcessorBase
import java.nio.file.FileSystems
import java.nio.file.Path
import java.util.concurrent.Executor
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicInteger
/**
* Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use.
@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) {
require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" }
}
val Executor.rootExecutor: Executor get() {
var executor: Executor = this
while (executor is ProcessorBase<*>) {
executor = executor.declaredField<Executor>("delegate").value
}
return executor
}
fun Executor.setThreadPoolName(threadPoolName: String) {
(rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) }
}
private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory {
companion object {
private val poolId = AtomicInteger(0)
}
private val prefix = "$poolName-${poolId.incrementAndGet()}-"
private val nextId = AtomicInteger(0)
override fun newThread(r: Runnable): Thread {
val thread = delegate.newThread(r)
thread.name = "$prefix${nextId.incrementAndGet()}"
return thread
}
}

View File

@ -14,15 +14,16 @@ import org.apache.activemq.artemis.utils.ConfigurationHelper
import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService
class NodeNettyConnectorFactory : ConnectorFactory {
class CordaNettyConnectorFactory : ConnectorFactory {
override fun createConnector(configuration: MutableMap<String, Any>?,
handler: BufferHandler?,
listener: ClientConnectionLifeCycleListener?,
closeExecutor: Executor?,
threadPool: Executor?,
scheduledThreadPool: ScheduledExecutorService?,
closeExecutor: Executor,
threadPool: Executor,
scheduledThreadPool: ScheduledExecutorService,
protocolManager: ClientProtocolManager?): Connector {
val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration)
setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName)
val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
return NettyConnector(
configuration,
@ -31,7 +32,7 @@ class NodeNettyConnectorFactory : ConnectorFactory {
closeExecutor,
threadPool,
scheduledThreadPool,
MyClientProtocolManager(threadPoolName, trace)
MyClientProtocolManager("$threadPoolName-netty", trace)
)
}
@ -39,6 +40,17 @@ class NodeNettyConnectorFactory : ConnectorFactory {
override fun getDefaults(): Map<String?, Any?> = NettyConnector.DEFAULT_CONFIG
private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) {
threadPool.setThreadPoolName("$name-artemis")
// Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool
// and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names.
if (threadPool.rootExecutor !== closeExecutor.rootExecutor) {
closeExecutor.setThreadPoolName("$name-artemis-closer")
}
// The scheduler is separate
scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler")
}
private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() {
override fun addChannelHandlers(pipeline: ChannelPipeline) {

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -22,6 +22,7 @@ import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
@ -31,6 +32,7 @@ import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC
import rx.Subscription
import java.time.Duration
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.ScheduledFuture
@ -53,7 +55,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean,
sslHandshakeTimeout: Duration?,
@ -78,9 +80,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout)
private var sharedEventLoopGroup: EventLoopGroup? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
private var artemis: ArtemisSessionProvider? = null
companion object {
private val log = contextLogger()
private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads"
@ -97,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
* however Artemis and the remote Corda instanced will deduplicate these messages.
*/
@Suppress("TooManyFunctions")
private class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration,
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService?,
private val bridgeConnectionTTLSeconds: Int) {
companion object {
private val log = contextLogger()
}
private inner class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration) {
private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>()
@ -116,7 +113,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() })
MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block()
} finally {
@ -134,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
val amqpClient = AMQPClient(
targets,
allowedRemoteLegalNames,
amqpConfig,
AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!)
)
private var session: ClientSession? = null
private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null
@Volatile
private var messagesReceived: Boolean = false
private val eventLoop: EventLoop = sharedEventGroup.next()
private val eventLoop: EventLoop = sharedEventLoopGroup!!.next()
private var artemisState: ArtemisState = ArtemisState.STOPPED
set(value) {
logDebugWithMDC { "State change $field to $value" }
@ -152,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private var scheduledExecutorService: ScheduledExecutorService
= Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build())
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) {
val runnable = {
synchronized(artemis) {
synchronized(artemis!!) {
try {
val precedingState = artemisState
artemisState.pending?.cancel(false)
@ -231,7 +210,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
ArtemisState.STOPPING
}
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe()
connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -243,7 +222,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) {
logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames)
bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -253,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
}
artemis(ArtemisState.STARTING) {
val startedArtemis = artemis.started
val startedArtemis = artemis!!.started
if (startedArtemis == null) {
logInfoWithMDC("Bridge Connected but Artemis is disconnected")
ArtemisState.STOPPED
@ -286,7 +265,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
}
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -418,10 +397,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value
}
}
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(),
allowedRemoteLegalNames.first().toString(),
properties)
sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -457,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
}
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
lock.withLock {
val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() }
@ -467,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
}
val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize,
revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) }
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!,
bridgeMetricsService, bridgeConnectionTTLSeconds)
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig)
bridges += newBridge
bridgeMetricsService?.bridgeCreated(targets, legalNames)
newBridge
@ -486,7 +487,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName)
}
bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames)
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
}
}
}
@ -497,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
// queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method.
val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap()
bridges?.associate {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
} ?: emptyMap()
}
}
override fun start() {
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("AMQPBridge", Thread.MAX_PRIORITY))
val artemis = artemisMessageClientFactory()
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY))
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge")
val artemis = artemisMessageClientFactory("ArtemisBridge")
this.artemis = artemis
artemis.start()
}
@ -522,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
sharedEventLoopGroup = null
queueNamesToBridgesMap.clear()
artemis?.stop()
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
}
}
}

View File

@ -35,7 +35,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean = false,
sslHandshakeTimeout: Duration? = null,
@ -80,7 +80,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId"
bridgeManager.start()
val artemis = artemisMessageClientFactory()
val artemis = artemisMessageClientFactory("BridgeControl")
this.artemis = artemis
artemis.start()
val artemisClient = artemis.started!!

View File

@ -37,7 +37,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null,
private val isLocalInbox: (String) -> Boolean,
trace: Boolean,
@ -204,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
override fun start() {
super.start()
val artemis = artemisMessageClientFactory()
val artemis = artemisMessageClientFactory("LoopbackBridge")
this.artemis = artemis
artemis.start()
}

View File

@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.*
import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.*
import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -32,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair
import java.security.PublicKey
import java.security.SignatureException
import java.security.cert.*
import java.security.cert.CertPath
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.time.temporal.ChronoUnit
@ -359,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint)))
val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer))
}
@ -379,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes)
}
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
}
// Assuming cert type to role is 1:1

View File

@ -27,16 +27,17 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.Executor
import java.util.concurrent.ExecutorService
import java.util.concurrent.ThreadPoolExecutor
import java.time.Duration
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
enum class ProxyVersion {
@ -63,8 +64,8 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA
class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null,
private val threadPoolName: String = "AMQPClient") : AutoCloseable {
private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"),
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -84,14 +85,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private val lock = ReentrantLock()
@Volatile
private var started: Boolean = false
private var workerGroup: EventLoopGroup? = null
@Volatile
private var clientChannel: Channel? = null
// Offset into the list of targets, so that we can implement round-robin reconnect logic.
private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private val handshakeFailureRetryTargets = mutableSetOf<NetworkHostAndPort>()
private var retryingHandshakeFailures = false
private var retryOffset = 0
@ -172,7 +171,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
log.info("Failed to connect to $currentTarget", future.cause())
if (started) {
workerGroup?.schedule({
nettyThreading.eventLoopGroup.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
@ -191,7 +190,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
clientChannel = null
if (started && !amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" }
workerGroup?.schedule({
nettyThreading.eventLoopGroup.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
@ -199,17 +198,16 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
@Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
@Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
@ -249,9 +247,22 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget
val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
parent.nettyThreading.sslDelegatedTaskExecutor
)
} else {
createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory)
createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
parent.nettyThreading.sslDelegatedTaskExecutor
)
}
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler)
@ -292,7 +303,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
if (started && amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP active)" }
workerGroup?.schedule({
nettyThreading.eventLoopGroup.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
@ -309,7 +320,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
return
}
log.info("Connect to: $currentTarget")
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
(nettyThreading as? NettyThreading.NonShared)?.start()
started = true
restart()
}
@ -321,7 +332,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
}
val bootstrap = Bootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
// Delegate DNS Resolution to the proxy side, if we are using proxy.
if (configuration.proxyConfig != null) {
bootstrap.resolver(NoopAddressResolverGroup.INSTANCE)
@ -335,14 +346,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
lock.withLock {
log.info("Stopping connection to: $currentTarget, Local address: $localAddressString")
started = false
if (sharedThreadPool == null) {
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
if (nettyThreading is NettyThreading.NonShared) {
nettyThreading.stop()
} else {
clientChannel?.close()?.sync()
}
clientChannel = null
workerGroup = null
log.info("Stopped connection to $currentTarget")
}
}
@ -384,5 +393,35 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}
sealed class NettyThreading {
abstract val eventLoopGroup: EventLoopGroup
abstract val sslDelegatedTaskExecutor: Executor
class Shared(override val eventLoopGroup: EventLoopGroup,
override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading()
class NonShared(val threadPoolName: String) : NettyThreading() {
private var _eventLoopGroup: NioEventLoopGroup? = null
override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup)
private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null
override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor)
fun start() {
check(_eventLoopGroup == null)
check(_sslDelegatedTaskExecutor == null)
_eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
_sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
}
fun stop() {
eventLoopGroup.shutdownGracefully()
eventLoopGroup.terminationFuture().sync()
sslDelegatedTaskExecutor.shutdown()
_eventLoopGroup = null
_sslDelegatedTaskExecutor = null
}
}
}
}

View File

@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery
import rx.Observable
import rx.subjects.PublishSubject
import java.net.BindException
import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock
/**
@ -39,37 +38,34 @@ import kotlin.concurrent.withLock
class AMQPServer(val hostName: String,
val port: Int,
private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer") : AutoCloseable {
private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object {
init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4)
private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
}
private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline()
@ -116,11 +112,12 @@ class AMQPServer(val hostName: String,
Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc())
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else {
// For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory)
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
}
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory))
@ -132,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock {
stop()
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -154,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() {
lock.withLock {
try {
stopping = true
serverChannel?.apply { close() }
serverChannel = null
serverChannel?.close()
serverChannel = null
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync()
bossGroup = null
workerGroup = null
bossGroup = null
} finally {
stopping = false
}
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
}
}
@ -226,6 +225,4 @@ class AMQPServer(val hostName: String,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange>
get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
}

View File

@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return Collections.emptyList()
}
override fun clone(): AllowAllRevocationChecker = this
}

View File

@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import com.typesafe.config.Config
import net.corda.nodeapi.internal.config.ConfigParser
import net.corda.nodeapi.internal.config.CustomConfigParser
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import java.security.cert.PKIXRevocationChecker
/**
* Data structure for controlling the way how Certificate Revocation Lists are handled.
@ -45,18 +42,6 @@ interface RevocationConfig {
* Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/
val externalCrlSource: CrlSource?
fun createPKIXRevocationChecker(): PKIXRevocationChecker {
return when (mode) {
Mode.OFF -> AllowAllRevocationChecker
Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" }
CordaRevocationChecker(externalCrlSource, softFail = true)
}
Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true)
Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false)
}
}
}
/**

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.ASN1IA5String
@ -34,10 +38,10 @@ import java.net.URI
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate
import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal
import kotlin.system.measureTimeMillis
private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default"
@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton
/**
* Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/
@Suppress("ComplexMethod")
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" }
@ -117,6 +119,14 @@ fun certPathToString(certPath: Array<out X509Certificate>?): String {
return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
}
/**
* Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
* which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
*/
fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
}
@VisibleForTesting
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object {
@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
}
private object LoggingImmediateExecutor : Executor {
override fun execute(command: Runnable) {
val log = LoggerFactory.getLogger(javaClass)
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true
@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters
}
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler {
trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false
@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler {
alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
return SslHandler(sslEngine, false, delegateTaskExecutor)
}
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext {
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java)
.map { LoggingTrustManagerWrapper(it) }.toTypedArray()
sslContext.init(keyManagers, trustManagers, newSecureRandom())
val trustManagers = trustManagerFactory
?.trustManagers
?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationConfig: RevocationConfig): CertPathTrustManagerParameters {
return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker())
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
}
/**
* Creates a special SNI handler used only when openSSL is used for AMQPServer
*/
@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build())
}
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray())
.protocols(ArtemisTcpTransport.TLS_VERSIONS)
}
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
@ -327,7 +307,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
// 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal)
fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
return keyManagerFactory
}
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache
import net.corda.core.internal.readFully
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints
import java.net.URI
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import java.time.Duration
import javax.security.auth.x500.X500Principal
/**
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate.
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them.
*/
@Suppress("TooGenericExceptionCaught")
class CertDistPointCrlSource : CrlSource {
class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE,
cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY,
private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT,
private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource {
companion object {
private val logger = contextLogger()
// The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a
// node handshake, we want to keep the total timeout within that.
private const val DEFAULT_CONNECT_TIMEOUT = 9_000
private const val DEFAULT_READ_TIMEOUT = 9_000
private val DEFAULT_CONNECT_TIMEOUT = 9.seconds
private val DEFAULT_READ_TIMEOUT = 9.seconds
private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore)
private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L
private val DEFAULT_CACHE_EXPIRY = 5.minutes
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE))
.expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS)
.build(::retrieveCRL)
val SINGLETON = CertDistPointCrlSource(
cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE),
cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY,
connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT,
readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT
)
}
private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT)
private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT)
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
.maximumSize(cacheSize)
.expireAfterWrite(cacheExpiry)
.build(::retrieveCRL)
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout
conn.readTimeout = readTimeout
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
throw e
private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis()
val bytes = try {
val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout.toMillis().toInt()
conn.readTimeout = readTimeout.toMillis().toInt()
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain.
conn.getInputStream().readFully()
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
throw e
}
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
}
fun clearCache() {
cache.invalidateAll()
}
override fun fetch(certificate: X509Certificate): Set<X509CRL> {

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Rule
import org.junit.Test
@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions
import org.junit.Ignore
import org.junit.Rule
@ -18,7 +19,6 @@ import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import javax.net.ssl.*
import javax.net.ssl.SNIHostName
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertFalse
@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(keyStore)
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(trustStore)
val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom())
}

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.util.concurrent.ImmediateExecutor
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.testing.internal.fixedCrlSource
import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals
class SSLHelperTest {
@ -20,15 +20,21 @@ class SSLHelperTest {
val legalName = CordaX500Name("Test", "London", "GB")
val sslConfig = configureTestSSL(legalName)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyStore = sslConfig.keyStore
keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false))
val trustStore = sslConfig.trustStore
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)))
val trustManagerFactory = trustManagerFactoryWithRevocation(
sslConfig.trustStore.get(),
RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL),
fixedCrlSource(emptySet())
)
val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory)
val sslHandler = createClientSslHandler(
NetworkHostAndPort("localhost", 1234),
setOf(legalName),
keyManagerFactory,
trustManagerFactory,
ImmediateExecutor.INSTANCE
)
val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
// These hardcoded values must not be changed, something is broken if you have to change these hardcoded values.

View File

@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.testing.core.ALICE_NAME
import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.node.internal.network.CrlServer
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.math.BigInteger
class CertDistPointCrlSourceTest {
private lateinit var crlServer: CrlServer
@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest {
assertThat(single().revokedCertificates).isNull()
}
val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate)
crlSource.clearCache()
crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN)
with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache
crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate
with(crlSource.fetch(crlServer.intermediateCa.certificate)) {
assertThat(size).isEqualTo(1)
val revokedCertificates = single().revokedCertificates
assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN)
// This also tests clearCache() works.
assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber)
}
}
}

View File

@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.testing.internal.fixedCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test
import java.math.BigInteger
@ -41,10 +41,8 @@ class CordaRevocationCheckerTest {
val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL
val crlSource = object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl)
}
val checker = CordaRevocationChecker(crlSource,
val checker = CordaRevocationChecker(
crlSource = fixedCrlSource(setOf(crl)),
softFail = true,
dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) }
)

View File

@ -1,20 +1,16 @@
package net.corda.node.internal.artemis
package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto
import net.corda.core.utilities.days
import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509KeyStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.junit.Before
import org.junit.Rule
import org.junit.Test
@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import java.io.File
import java.security.KeyPair
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import java.util.*
import javax.net.ssl.X509TrustManager
import javax.security.auth.x500.X500Principal
import kotlin.test.assertFails
import kotlin.test.assertFailsWith
@RunWith(Parameterized::class)
class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
class RevocationTest(private val revocationMode: RevocationConfig.Mode) {
companion object {
@JvmStatic
@Parameterized.Parameters(name = "revocationMode = {0}")
@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var doormanCRL: File
private lateinit var tlsCRL: File
private val keyStore = KeyStore.getInstance("JKS")
private val trustStore = KeyStore.getInstance("JKS")
private lateinit var trustManager: X509TrustManager
private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
@ -61,7 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var tlsCert: X509Certificate
private val chain
get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).toTypedArray()
get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert)
@Before
fun before() {
@ -72,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair)
tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair)
val trustStore = KeyStore.getInstance("JKS")
trustStore.load(null, null)
trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert)
trustStore.setCertificateEntry("cordarootca", rootCert)
val trustManagerFactory = trustManagerFactoryWithRevocation(
CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"),
RevocationConfigImpl(revocationMode),
CertDistPointCrlSource()
)
trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager
doormanCert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public,
crlDistPoint = rootCRL.toURI().toString()
@ -89,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
rootCRL.createCRL(rootCert, rootKeyPair.private, false)
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
rootCRL.writeCRL(rootCert, rootKeyPair.private, false)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
}
private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date())
builder.setNextUpdate(Date.from(Date().toInstant() + 7.days))
builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false))
revoked.forEach {
val extensionsGenerator = ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise))
// Certificate issuer is required for indirect CRL
val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded)
extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName)))
builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate())
}
val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey))
outputStream().use { it.write(holder.encoded) }
private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val crl = createCRL(
CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)),
revoked.asList(),
indirect = indirect
)
writeBytes(crl.encoded)
}
private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) {
if (revocationMode in modes) assertFails(block) else block()
private fun assertFailsFor(vararg modes: RevocationConfig.Mode) {
if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck()
}
@Test(timeout = 300_000)
fun `ok with empty CRLs`() {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked TLS certificate`() {
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -136,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -148,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown")
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -160,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString()
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -172,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
)
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
doRevocationCheck()
}
@Test(timeout = 300_000)
fun `soft fail with revoked node CA certificate`() {
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -193,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman"
)
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
}
@Test(timeout = 300_000)
@ -205,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString()
)
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert)
doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
doRevocationCheck()
}
private fun doRevocationCheck() {
trustManager.checkClientTrusted(chain, "ECDHE_ECDSA")
}
}

View File

@ -270,8 +270,6 @@ tasks.register('integrationTest', Test) {
testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath
maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger()
// CertificateRevocationListNodeTests
systemProperty 'net.corda.dpcrl.connect.timeout', '4000'
}
tasks.register('slowIntegrationTest', Test) {

View File

@ -15,12 +15,15 @@ import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionResult
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.fixedCrlSource
import org.junit.Assume.assumeFalse
import org.junit.Before
import org.junit.Rule
@ -98,11 +101,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val maxMessageSize: Int = MAX_MESSAGE_SIZE
}
serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
serverKeyManagerFactory = keyManagerFactory(keyStore)
serverKeyManagerFactory.init(keyStore)
serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig))
serverTrustManagerFactory = trustManagerFactoryWithRevocation(
serverAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
}
private fun setupClientCertificates() {
@ -129,11 +134,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val sslHandshakeTimeout: Duration = 3.seconds
}
clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
clientKeyManagerFactory = keyManagerFactory(keyStore)
clientKeyManagerFactory.init(keyStore)
clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig))
clientTrustManagerFactory = trustManagerFactoryWithRevocation(
clientAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
}
@Test(timeout = 300_000)

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.node.amqp
import com.nhaarman.mockito_kotlin.doReturn
@ -5,10 +7,10 @@ import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.Crypto
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.internal.rootCause
import net.corda.core.internal.times
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration
@ -18,64 +20,68 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.config.CertificateStoreSupplier
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.node.internal.network.CrlServer
import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.FORBIDDEN_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint
import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.net.SocketTimeoutException
import java.io.Closeable
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals
import java.util.stream.IntStream
@Suppress("LongParameterList")
class CertificateRevocationListNodeTests {
abstract class AbstractServerRevocationTest {
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
private val portAllocation = incrementalPortAllocation()
private val serverPort = portAllocation.nextPort()
protected val serverPort = portAllocation.nextPort()
private lateinit var crlServer: CrlServer
private lateinit var amqpServer: AMQPServer
private lateinit var amqpClient: AMQPClient
protected lateinit var crlServer: CrlServer
private val amqpClients = ArrayList<AMQPClient>()
private abstract class AbstractNodeConfiguration : NodeConfiguration
protected lateinit var defaultCrlDistPoints: CrlDistPoints
protected abstract class AbstractNodeConfiguration : NodeConfiguration
companion object {
private val unreachableIpCounter = AtomicInteger(1)
private val crlConnectTimeout = Duration.ofMillis(System.getProperty("net.corda.dpcrl.connect.timeout").toLong())
val crlConnectTimeout = 2.seconds
/**
* Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes
* may not work as the OS process may cache the timeout result.
*/
private fun newUnreachableIpAddress(): String {
private fun newUnreachableIpAddress(): NetworkHostAndPort {
check(unreachableIpCounter.get() != 255)
return "10.255.255.${unreachableIpCounter.getAndIncrement()}"
return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement())
}
}
@ -85,252 +91,190 @@ class CertificateRevocationListNodeTests {
Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME)
crlServer = CrlServer(NetworkHostAndPort("localhost", 0))
crlServer.start()
defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort)
}
@After
fun tearDown() {
if (::amqpClient.isInitialized) {
amqpClient.close()
}
if (::amqpServer.isInitialized) {
amqpServer.close()
}
amqpClients.parallelStream().forEach(AMQPClient::close)
if (::crlServer.isInitialized) {
crlServer.close()
}
}
@Test(timeout=300_000)
fun `AMQP server connection works and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
expectedConnectStatus = true
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection works and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection succeeds when soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
expectedConnectStatus = true
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection fails when client's certificate is revoked and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
revokeClientCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when client's certificate is revoked and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
revokeClientCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection fails when server's certificate is revoked and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
revokeServerCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when server's certificate is revoked and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
revokeServerCert = true,
expectedConnectStatus = false
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl",
expectedConnectStatus = true
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when CRL cannot be obtained and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when CRL cannot be obtained and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl",
expectedConnectStatus = false
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL is not defined and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = null,
expectedConnectStatus = true
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection fails when CRL is not defined and soft fail is disabled`() {
verifyAMQPConnection(
fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
nodeCrlDistPoint = null,
expectedConnectStatus = false
clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() {
verifyAMQPConnection(
fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL",
expectedConnectStatus = true
clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() {
verifyAMQPConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout * 3,
expectedConnectStatus = true
fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() {
verifyConnection(
crlCheckSoftFail = false,
clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
expectedConnectedStatus = false
)
val timeoutExceptions = (amqpServer.softFailExceptions + amqpClient.softFailExceptions)
.map { it.rootCause }
.filterIsInstance<SocketTimeoutException>()
assertThat(timeoutExceptions).isNotEmpty
}
@Test(timeout=300_000)
fun `AMQP server connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() {
verifyAMQPConnection(
fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() {
verifyConnection(
crlCheckSoftFail = true,
sslHandshakeTimeout = crlConnectTimeout * 4,
clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() {
verifyConnection(
crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2,
expectedConnectStatus = false
clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = false
)
}
@Test(timeout=300_000)
fun `verify CRL algorithms`() {
val crl = crlServer.createRevocationList(
"SHA256withECDSA",
crlServer.rootCa,
EMPTY_CRL,
true,
emptyList()
@Test(timeout = 300_000)
fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() {
val serverCrlSource = CertDistPointCrlSource()
// Start the server and verify the first client has connected
val firstClientConnectionChangeStatus = verifyConnection(
crlCheckSoftFail = true,
crlSource = serverCrlSource,
// In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with
// existing clients. The trick is to make sure at least N new clients are also supported.
remotingThreads = 2,
expectedConnectedStatus = true
)
// This should pass.
crl.verify(crlServer.rootCa.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
crlServer.createRevocationList(
"EC",
crlServer.rootCa,
EMPTY_CRL,
true,
emptyList()
// Now simulate the CRL endpoint becoming very slow/unreachable
crlServer.delay = 10.minutes
// And pretend enough time has elapsed that the cached CRLs have expired and need downloading again
serverCrlSource.clearCache()
// Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty
// threads to be tied up in trying to download the CRLs.
IntStream.range(0, 2).parallel().forEach { clientIndex ->
val (newClient, _) = createAMQPClient(
serverPort,
crlCheckSoftFail = true,
legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"),
crlDistPoints = defaultCrlDistPoints
)
}.withMessage("Unknown signature type requested: EC")
newClient.start()
}
// Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga
assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull()
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged
)
}
protected abstract fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout),
sslHandshakeTimeout: Duration? = null,
remotingThreads: Int? = null,
clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints,
revokeClientCert: Boolean = false,
revokeServerCert: Boolean = false,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange>
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with hard fail CRL check`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unreachable URL if CRL timeout is within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout * 3
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on unreachable URL if CRL timeout is not within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedConnected = false,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
revokedNodeCert = true
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = false,
expectedStatus = MessageStatus.Acknowledged,
revokedNodeCert = true
)
}
private fun createAMQPClient(targetPort: Int,
crlCheckSoftFail: Boolean,
legalName: CordaX500Name,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL",
maxMessageSize: Int = MAX_MESSAGE_SIZE): X509Certificate {
protected fun createAMQPClient(targetPort: Int,
crlCheckSoftFail: Boolean,
legalName: CordaX500Name,
crlDistPoints: CrlDistPoints): Pair<AMQPClient, X509Certificate> {
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
@ -344,31 +288,128 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
}
clientConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint)
val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = clientConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore
override val trustStore = clientConfig.p2pSslOptions.trustStore.get()
override val maxMessageSize: Int = maxMessageSize
override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val trace: Boolean = true
}
amqpClient = AMQPClient(
val amqpClient = AMQPClient(
listOf(NetworkHostAndPort("localhost", targetPort)),
setOf(CHARLIE_NAME),
amqpConfig,
threadPoolName = legalName.organisation
nettyThreading = AMQPClient.NettyThreading.NonShared(legalName.organisation),
distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout)
)
amqpClients += amqpClient
return Pair(amqpClient, nodeCert)
}
return nodeCert
protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val connectionChangeStatus = LinkedBlockingQueue<ConnectionChange>()
onConnection.subscribe { connectionChangeStatus.add(it) }
start()
assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus)
return connectionChangeStatus
}
protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort,
val nodeCa: String? = NODE_CRL,
val tls: String? = EMPTY_CRL) {
private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" }
private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" }
fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier,
p2pSslConfiguration: MutualSslConfiguration,
crlServer: CrlServer): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
}
}
class AMQPServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var amqpServer: AMQPServer
@After
fun shutDown() {
if (::amqpServer.isInitialized) {
amqpServer.close()
}
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(serverCert)
}
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (client, clientCert) = createAMQPClient(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
legalName = ALICE_NAME,
crlDistPoints = clientCrlDistPoints
)
if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert)
}
return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
}
private fun createAMQPServer(port: Int,
legalName: CordaX500Name,
crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL",
maxMessageSize: Int = MAX_MESSAGE_SIZE,
sslHandshakeTimeout: Duration? = null): X509Certificate {
crlDistPoints: CrlDistPoints,
distPointCrlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?): X509Certificate {
check(!::amqpServer.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates"
@ -382,92 +423,103 @@ class CertificateRevocationListNodeTests {
doReturn(signingCertificateStore).whenever(it).signingCertificateStore
}
serverConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint)
val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = serverConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore
override val trustStore = serverConfig.p2pSslOptions.trustStore.get()
override val revocationConfig = crlCheckSoftFail.toRevocationConfig()
override val maxMessageSize: Int = maxMessageSize
override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout
}
amqpServer = AMQPServer("0.0.0.0", port, amqpConfig, threadPoolName = legalName.organisation)
return nodeCert
}
private fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier,
p2pSslConfiguration: MutualSslConfiguration,
nodeCaCrlDistPoint: String?,
tlsCrlDistPoint: String?): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
private fun verifyAMQPConnection(crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
revokeServerCert: Boolean = false,
revokeClientCert: Boolean = false,
sslHandshakeTimeout: Duration? = null,
expectedConnectStatus: Boolean) {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail = crlCheckSoftFail,
nodeCrlDistPoint = nodeCrlDistPoint,
sslHandshakeTimeout = sslHandshakeTimeout
amqpServer = AMQPServer(
"0.0.0.0",
port,
amqpConfig,
threadPoolName = legalName.organisation,
distPointCrlSource = distPointCrlSource,
remotingThreads = remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(serverCert.serialNumber)
return serverCert
}
}
class ArtemisServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var artemisNode: ArtemisNode
private var crlCheckArtemisServer = true
@After
fun shutDown() {
if (::artemisNode.isInitialized) {
artemisNode.close()
}
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val clientCert = createAMQPClient(
}
@Test(timeout = 300_000)
fun `connection succeeds with disabled CRL check on revoked node certificate`() {
crlCheckArtemisServer = false
verifyConnection(
crlCheckSoftFail = false,
revokeClientCert = true,
expectedConnectedStatus = true
)
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val (client, clientCert) = createAMQPClient(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
crlCheckSoftFail = true,
legalName = ALICE_NAME,
nodeCrlDistPoint = nodeCrlDistPoint
crlDistPoints = clientCrlDistPoints
)
if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert.serialNumber)
crlServer.revokedNodeCerts.add(clientCert)
}
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertThat(serverConnect.connected).isEqualTo(expectedConnectStatus)
val nodeCert = startArtemisNode(
CHARLIE_NAME,
crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(nodeCert)
}
val queueName = "${P2P_PREFIX}Test"
artemisNode.client.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
if (expectedConnectedStatus) {
val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
client.write(msg)
assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged)
}
return clientConnectionChangeStatus
}
private fun createArtemisServerAndClient(legalName: CordaX500Name,
crlCheckSoftFail: Boolean,
crlCheckArtemisServer: Boolean,
nodeCrlDistPoint: String,
sslHandshakeTimeout: Duration?): Pair<ArtemisMessagingServer, ArtemisMessagingClient> {
val baseDirectory = temporaryFolder.root.toPath() / "artemis"
private fun startArtemisNode(legalName: CordaX500Name,
crlCheckSoftFail: Boolean,
crlDistPoints: CrlDistPoints,
distPointCrlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?): X509Certificate {
check(!::artemisNode.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout)
@ -483,62 +535,34 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer
}
artemisConfig.configureWithDevSSLCertificate()
recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, null)
val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val server = ArtemisMessagingServer(
artemisConfig,
artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-server",
trace = true
trace = true,
distPointCrlSource = distPointCrlSource,
remotingThreads = remotingThreads
)
val client = ArtemisMessagingClient(
artemisConfig.p2pSslOptions,
artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-client",
trace = true
threadPoolName = "${legalName.organisation}-client"
)
server.start()
client.start()
return server to client
val artemisNode = ArtemisNode(server, client)
this.artemisNode = artemisNode
return nodeCert
}
private fun verifyArtemisConnection(crlCheckSoftFail: Boolean,
crlCheckArtemisServer: Boolean,
expectedConnected: Boolean = true,
expectedStatus: MessageStatus? = null,
revokedNodeCert: Boolean = false,
nodeCrlDistPoint: String = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
sslHandshakeTimeout: Duration? = null) {
val queueName = P2P_PREFIX + "Test"
val (artemisServer, artemisClient) = createArtemisServerAndClient(
CHARLIE_NAME,
crlCheckSoftFail,
crlCheckArtemisServer,
nodeCrlDistPoint,
sslHandshakeTimeout
)
artemisServer.use {
artemisClient.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val nodeCert = createAMQPClient(serverPort, true, ALICE_NAME, nodeCrlDistPoint)
if (revokedNodeCert) {
crlServer.revokedNodeCerts.add(nodeCert.serialNumber)
}
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val clientConnect = clientConnected.get()
assertThat(clientConnect.connected).isEqualTo(expectedConnected)
if (expectedConnected) {
val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
amqpClient.write(msg)
assertEquals(expectedStatus, msg.onComplete.get())
}
artemisClient.stop()
private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable {
override fun close() {
client.stop()
server.close()
}
}
}

View File

@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer
@ -23,7 +26,9 @@ import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
@ -31,9 +36,6 @@ import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions
@ -44,7 +46,11 @@ import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit
import javax.net.ssl.*
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLHandshakeException
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import kotlin.concurrent.thread
import kotlin.test.assertEquals
import kotlin.test.assertTrue
@ -146,15 +152,10 @@ class ProtonWrapperTests {
sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) }
sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustMgrFactory.init(trustStore)
val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom())
@ -442,7 +443,7 @@ class ProtonWrapperTests {
amqpServer.use {
val connectionEvents = amqpServer.onConnection.toBlocking().iterator
amqpServer.start()
val sharedThreads = NioEventLoopGroup()
val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads"))
val amqpClient1 = createSharedThreadsClient(sharedThreads, 0)
val amqpClient2 = createSharedThreadsClient(sharedThreads, 1)
amqpClient1.start()
@ -608,7 +609,7 @@ class ProtonWrapperTests {
listOf(NetworkHostAndPort("localhost", serverPort)),
setOf(ALICE_NAME),
amqpConfig,
sharedThreadPool = sharedEventGroup)
nettyThreading = AMQPClient.NettyThreading.Shared(sharedEventGroup))
}
private fun createServer(port: Int,

View File

@ -3,6 +3,8 @@ package net.corda.services.messaging
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.qpid.jms.JmsConnectionFactory
import org.apache.qpid.jms.meta.JmsConnectionInfo
import org.apache.qpid.jms.provider.Provider
@ -24,9 +26,7 @@ import javax.jms.Connection
import javax.jms.Message
import javax.jms.MessageProducer
import javax.jms.Session
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
/**
* Simple AMQP client connecting to broker using JMS.
@ -59,12 +59,8 @@ class SimpleAMQPClient(private val target: NetworkHostAndPort, private val confi
private lateinit var connection: Connection
private fun sslContext(): SSLContext {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply {
init(config.keyStore.get().value.internal, config.keyStore.entryPassword.toCharArray())
}
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(config.trustStore.get().value.internal)
}
val keyManagerFactory = keyManagerFactory(config.keyStore.get())
val trustManagerFactory = trustManagerFactory(config.trustStore.get())
val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers

View File

@ -5,8 +5,8 @@ import com.codahale.metrics.Gauge
import com.codahale.metrics.MetricRegistry
import com.google.common.collect.MutableClassToInstanceMap
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.zaxxer.hikari.pool.HikariPool
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.common.logging.errorReporting.NodeDatabaseErrors
import net.corda.confidential.SwapIdentitiesFlow
import net.corda.core.CordaException
@ -73,6 +73,7 @@ import net.corda.core.toFuture
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes
import net.corda.djvm.source.ApiSource
import net.corda.djvm.source.EmptyApi
@ -172,6 +173,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager
import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import org.apache.activemq.artemis.utils.ReusableLatch
import org.jolokia.jvmagent.JolokiaServer
import org.jolokia.jvmagent.JolokiaServerConfig
@ -187,9 +189,6 @@ import java.util.ArrayList
import java.util.Properties
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS
import java.util.function.Consumer
@ -353,7 +352,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
private val cordappTelemetryComponents = MutableClassToInstanceMap.create<TelemetryComponent>()
private val shutdownExecutor = Executors.newSingleThreadExecutor()
private val shutdownExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("Shutdown"))
protected abstract val transactionVerifierWorkerCount: Int
/**
@ -808,7 +807,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
} else {
1.days
}
val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("Network Map Updater"))
val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("NetworkMapPublisher"))
executor.submit(object : Runnable {
override fun run() {
val republishInterval = try {
@ -917,13 +916,12 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
// Start with 1 thread and scale up to the configured thread pool size if needed
// Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool]
return ThreadPoolExecutor(
1,
numberOfThreads,
0L,
TimeUnit.MILLISECONDS,
LinkedBlockingQueue<Runnable>(),
ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build()
return namedThreadPoolExecutor(
corePoolSize = 1,
maxPoolSize = numberOfThreads,
idleKeepAlive = 0.millis,
poolName = "flow-external-operation-thread",
daemonThreads = true
)
}
@ -1174,7 +1172,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
networkParameters: NetworkParameters)
protected open fun makeVaultService(keyManagementService: KeyManagementService,
services: ServicesForResolution,
services: NodeServicesForResolution,
database: CordaPersistence,
cordappLoader: CordappLoader): VaultServiceInternal {
return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader)

View File

@ -415,12 +415,13 @@ open class Node(configuration: NodeConfiguration,
}
private fun makeBridgeControlListener(serverAddress: NetworkHostAndPort, networkParameters: NetworkParameters): BridgeControlListener {
val artemisMessagingClientFactory = {
val artemisMessagingClientFactory = { threadPoolName: String ->
ArtemisMessagingClient(
configuration.p2pSslOptions,
serverAddress,
networkParameters.maxMessageSize,
failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) }
failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) },
threadPoolName = threadPoolName
)
}
return BridgeControlListener(
@ -431,7 +432,8 @@ open class Node(configuration: NodeConfiguration,
networkParameters.maxMessageSize,
configuration.crlCheckSoftFail.toRevocationConfig(),
false,
artemisMessagingClientFactory)
artemisMessagingClientFactory
)
}
private fun startLocalRpcBroker(securityManager: RPCSecurityManager): BrokerAddresses? {

View File

@ -0,0 +1,15 @@
package net.corda.node.internal
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.node.ServicesForResolution
import java.util.LinkedHashSet
interface NodeServicesForResolution : ServicesForResolution {
@Throws(TransactionResolutionException::class)
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> = loadStates(stateRefs, LinkedHashSet())
fun <T : ContractState, C : MutableCollection<StateAndRef<T>>> loadStates(input: Iterable<StateRef>, output: C): C
}

View File

@ -1,16 +1,26 @@
package net.corda.node.internal
import net.corda.core.contracts.*
import net.corda.core.contracts.Attachment
import net.corda.core.contracts.AttachmentResolutionException
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.contracts.TransactionState
import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.SerializedStateAndRef
import net.corda.core.internal.uncheckedCast
import net.corda.core.node.NetworkParameters
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService
import net.corda.core.node.services.TransactionStorage
import net.corda.core.transactions.BaseTransaction
import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent
@ -20,31 +30,28 @@ data class ServicesForResolutionImpl(
override val cordappProvider: CordappProvider,
override val networkParametersService: NetworkParametersService,
private val validatedTransactions: TransactionStorage
) : ServicesForResolution {
) : NodeServicesForResolution {
override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?:
throw IllegalArgumentException("No current parameters in network parameters storage")
@Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return stx.resolveBaseTransaction(this).outputs[stateRef.index]
return toBaseTransaction(stateRef.txhash).outputs[stateRef.index]
}
@Throws(TransactionResolutionException::class)
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> {
return stateRefs.groupBy { it.txhash }.flatMap {
val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key)
val baseTx = stx.resolveBaseTransaction(this)
it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) }
}.toSet()
override fun <T : ContractState, C : MutableCollection<StateAndRef<T>>> loadStates(input: Iterable<StateRef>, output: C): C {
val baseTxs = HashMap<SecureHash, BaseTransaction>()
return input.mapTo(output) { stateRef ->
val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction)
StateAndRef(uncheckedCast(baseTx.outputs[stateRef.index]), stateRef)
}
}
@Throws(TransactionResolutionException::class, AttachmentResolutionException::class)
override fun loadContractAttachment(stateRef: StateRef): Attachment {
// We may need to recursively chase transactions if there are notary changes.
fun inner(stateRef: StateRef, forContractClassName: String?): Attachment {
val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction
?: throw TransactionResolutionException(stateRef.txhash)
val ctx = getSignedTransaction(stateRef.txhash).coreTransaction
when (ctx) {
is WireTransaction -> {
val transactionState = ctx.outRef<ContractState>(stateRef.index).state
@ -69,4 +76,10 @@ data class ServicesForResolutionImpl(
}
return inner(stateRef, null)
}
private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this)
private fun getSignedTransaction(txhash: SecureHash): SignedTransaction {
return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash)
}
}

View File

@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() {
Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE)))
}
ArtemisMessagingComponent.PEER_USER -> {
requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." }
val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." }
requireTls(certificates)
// This check is redundant as it was performed already during the SSL handshake
CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates)
CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode)
.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates)
CertificateChainCheckPolicy.RootMustMatch
.createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore)
.checkCertificateChain(certificates)
Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE)))
}
else -> {

View File

@ -2,17 +2,9 @@ package net.corda.node.internal.artemis
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString
import java.security.KeyStore
import java.security.cert.CertPathValidator
import java.security.cert.CertPathValidatorException
import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.X509CertSelector
sealed class CertificateChainCheckPolicy {
companion object {
@ -92,33 +84,4 @@ sealed class CertificateChainCheckPolicy {
}
}
}
class RevocationCheck(val revocationConfig: RevocationConfig) : CertificateChainCheckPolicy() {
constructor(revocationMode: RevocationConfig.Mode) : this(RevocationConfigImpl(revocationMode))
override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check {
return object : Check {
override fun checkCertificateChain(theirChain: Array<java.security.cert.X509Certificate>) {
// Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate.
val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) }
log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}")
// Drop the last certificate which must be a trusted root (validated by RootMustMatch).
// Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain.
// See PKIXValidator.engineValidate() for reference implementation.
val certPath = X509Utilities.buildCertPath(chain.dropLast(1))
val certPathValidator = CertPathValidator.getInstance("PKIX")
val pkixRevocationChecker = revocationConfig.createPKIXRevocationChecker()
val params = PKIXBuilderParameters(trustStore, X509CertSelector())
params.addCertPathChecker(pkixRevocationChecker)
try {
certPathValidator.validate(certPath, params)
} catch (ex: CertPathValidatorException) {
log.error("Bad certificate path", ex)
throw ex
}
}
}
}
}
}

View File

@ -2,7 +2,6 @@ package net.corda.node.migration
import liquibase.database.Database
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.node.services.Vault
import net.corda.core.schemas.MappedSchema
@ -19,6 +18,7 @@ import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.SchemaMigration
@ -62,8 +62,7 @@ class VaultStateMigration : CordaMigration() {
private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef<ContractState> {
val persistentStateRef = persistentState.stateRef ?:
throw VaultStateMigrationException("Persistent state ref missing from state")
val txHash = SecureHash.create(persistentStateRef.txId)
val stateRef = StateRef(txHash, persistentStateRef.index)
val stateRef = persistentStateRef.toStateRef()
val state = try {
servicesForResolution.loadState(stateRef)
} catch (e: Exception) {

View File

@ -2,6 +2,7 @@ package net.corda.node.services.events
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.context.InvocationOrigin
@ -148,7 +149,7 @@ class NodeSchedulerService(private val clock: CordaClock,
// from the database
private val startingStateRefs: MutableSet<ScheduledStateRef> = ConcurrentHashMap.newKeySet<ScheduledStateRef>()
private val mutex = ThreadBox(InnerState())
private val schedulerTimerExecutor = Executors.newSingleThreadExecutor()
private val schedulerTimerExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("SchedulerService"))
// if there's nothing to do, check every minute if something fell through the cracks.
// any new state should trigger a reschedule immediately if nothing is scheduled, so I would not expect

View File

@ -2,8 +2,8 @@ package net.corda.node.services.events
import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.SecureHash
import net.corda.core.schemas.PersistentStateRef
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence
interface ScheduledFlowRepository {
@ -25,9 +25,8 @@ class PersistentScheduledFlowRepository(val database: CordaPersistence) : Schedu
}
private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair<StateRef, ScheduledStateRef> {
val txId = scheduledStateRecord.output.txId
val index = scheduledStateRecord.output.index
return Pair(StateRef(SecureHash.create(txId), index), ScheduledStateRef(StateRef(SecureHash.create(txId), index), scheduledStateRecord.scheduledAt))
val stateRef = scheduledStateRecord.output.toStateRef()
return Pair(stateRef, ScheduledStateRef(stateRef, scheduledStateRecord.scheduledAt))
}
override fun delete(key: StateRef): Boolean {

View File

@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.internal.artemis.*
import net.corda.node.internal.artemis.ArtemisBroker
import net.corda.node.internal.artemis.BrokerAddresses
import net.corda.node.internal.artemis.BrokerJaasLoginModule
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE
import net.corda.node.internal.artemis.NodeJaasConfig
import net.corda.node.internal.artemis.P2PJaasConfig
import net.corda.node.internal.artemis.SecureArtemisConfiguration
import net.corda.node.internal.artemis.UserValidationPlugin
import net.corda.node.internal.artemis.isBindingError
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.utilities.artemis.startSynchronously
import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor
@ -21,7 +28,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.nodeapi.internal.requireOnDefaultFileSystem
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
@ -32,9 +42,7 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import java.io.IOException
import java.lang.Long.max
import java.security.KeyStoreException
import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.login.AppConfigurationEntry
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED
@ -57,8 +65,10 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
private val messagingServerAddress: NetworkHostAndPort,
private val maxMessageSize: Int,
private val journalBufferTimeout : Int? = null,
private val threadPoolName: String = "ArtemisServer",
private val trace: Boolean = false) : ArtemisBroker, SingletonSerializeAsToken() {
private val threadPoolName: String = "P2PServer",
private val trace: Boolean = false,
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() {
companion object {
private val log = contextLogger()
}
@ -92,7 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
override val started: Boolean
get() = activeMQServer.isStarted
@Throws(IOException::class, AddressBindingException::class, KeyStoreException::class)
@Suppress("ThrowsCount")
private fun configureAndStartServer() {
val artemisConfig = createArtemisConfig()
val securityManager = createArtemisSecurityManager()
@ -132,11 +142,23 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
// The transaction cache is configurable, and drives other cache sizes.
globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize)
val revocationMode = if (config.crlCheckArtemisServer) {
if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL
} else {
RevocationConfig.Mode.OFF
}
val trustManagerFactory = trustManagerFactoryWithRevocation(
config.p2pSslOptions.trustStore.get(),
RevocationConfigImpl(revocationMode),
distPointCrlSource
)
addAcceptorConfiguration(p2pAcceptorTcpTransport(
NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port),
config.p2pSslOptions,
trustManagerFactory,
threadPoolName = threadPoolName,
trace = trace
trace = trace,
remotingThreads = remotingThreads
))
// Enable built in message deduplication. Note we still have to do our own as the delayed commits
// and our own definition of commit mean that the built in deduplication cannot remove all duplicates.

View File

@ -10,6 +10,8 @@ import io.netty.handler.ssl.SslHandshakeTimeoutException
import net.corda.core.internal.declaredField
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import net.corda.nodeapi.internal.setThreadPoolName
import org.apache.activemq.artemis.api.core.BaseInterceptor
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor
import org.apache.activemq.artemis.core.server.balancing.RedirectHandler
@ -19,14 +21,18 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor
import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory
import org.apache.activemq.artemis.spi.core.remoting.BufferHandler
import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener
import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider
import org.apache.activemq.artemis.utils.ConfigurationHelper
import org.apache.activemq.artemis.utils.actors.OrderedExecutor
import java.net.SocketAddress
import java.nio.channels.ClosedChannelException
import java.time.Duration
import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService
import java.util.regex.Pattern
import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException
@Suppress("unused") // Used via reflection in ArtemisTcpTransport
class NodeNettyAcceptorFactory : AcceptorFactory {
@ -36,10 +42,23 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
handler: BufferHandler?,
listener: ServerConnectionLifeCycleListener?,
threadPool: Executor,
scheduledThreadPool: ScheduledExecutorService?,
scheduledThreadPool: ScheduledExecutorService,
protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?): Acceptor {
val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Acceptor", configuration)
threadPool.setThreadPoolName("$threadPoolName-artemis")
scheduledThreadPool.setThreadPoolName("$threadPoolName-artemis-scheduler")
val failureExecutor = OrderedExecutor(threadPool)
return NodeNettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap)
return NodeNettyAcceptor(
name,
clusterConnection,
configuration,
handler,
listener,
scheduledThreadPool,
failureExecutor,
protocolMap,
"$threadPoolName-netty"
)
}
@ -50,14 +69,21 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
listener: ServerConnectionLifeCycleListener?,
scheduledThreadPool: ScheduledExecutorService?,
failureExecutor: Executor,
protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?) :
protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?,
private val threadPoolName: String) :
NettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap)
{
companion object {
private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""")
init {
// Make sure Artemis isn't using another (Open)SSLContextFactory
check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory)
check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory)
}
}
private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration)
private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
@Synchronized
@ -71,11 +97,17 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
}
}
@Synchronized
override fun stop() {
super.stop()
sslDelegatedTaskExecutor.shutdown()
}
@Synchronized
override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler {
applyThreadPoolName()
val engine = super.getSslHandler(alloc, peerHost, peerPort).engine()
val sslHandler = NodeAcceptorSslHandler(engine, trace)
val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace)
val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration?
if (handshakeTimeout != null) {
sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis()
@ -95,13 +127,15 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
}
private class NodeAcceptorSslHandler(engine: SSLEngine, private val trace: Boolean) : SslHandler(engine) {
private class NodeAcceptorSslHandler(engine: SSLEngine,
delegatedTaskExecutor: Executor,
private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) {
companion object {
private val logger = contextLogger()
}
override fun handlerAdded(ctx: ChannelHandlerContext) {
logHandshake()
logHandshake(ctx.channel().remoteAddress())
super.handlerAdded(ctx)
// Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way.
if (trace) {
@ -109,17 +143,22 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
}
}
private fun logHandshake() {
private fun logHandshake(remoteAddress: SocketAddress) {
val start = System.currentTimeMillis()
handshakeFuture().addListener {
val duration = System.currentTimeMillis() - start
val peer = try {
engine().session.peerPrincipal
} catch (e: SSLPeerUnverifiedException) {
remoteAddress
}
when {
it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with ${engine().session.peerPrincipal}")
it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms")
it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with $peer")
it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms with $peer")
else -> when (it.cause()) {
is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms")
is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms")
else -> logger.warn("SSL handshake failed after ${duration}ms", it.cause())
is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms with $peer")
is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms with $peer")
else -> logger.warn("SSL handshake failed after ${duration}ms with $peer", it.cause())
}
}
}

View File

@ -0,0 +1,59 @@
package net.corda.node.services.messaging
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig
import java.nio.file.Paths
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
class NodeSSLContextFactory : DefaultSSLContextFactory() {
override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SSLContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory)
} else {
super.getSSLContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() {
override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SslContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
SslContextBuilder
.forServer(loadKeyManagerFactory(config))
.sslProvider(SslProvider.OPENSSL)
.trustManager(trustManagerFactory)
.build()
} else {
super.getServerSslContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory {
val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false)
return keyManagerFactory(keyStore)
}

View File

@ -74,7 +74,7 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal,
}
private val parametersUpdatesTrack = PublishSubject.create<ParametersUpdateInfo>()
private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("Network Map Updater Thread")).apply {
private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("NetworkMapUpdater")).apply {
executeExistingDelayedTasksAfterShutdownPolicy = false
}
private var newNetworkParameters: Pair<ParametersUpdate, SignedNetworkParameters>? = null
@ -261,9 +261,12 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal,
//as HTTP GET is mostly IO bound, use more threads than CPU's
//maximum threads to use = 24, as if we did not limit this on large machines it could result in 100's of concurrent requests
val threadsToUseForNetworkMapDownload = min(Runtime.getRuntime().availableProcessors() * 4, 24)
val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(threadsToUseForNetworkMapDownload, NamedThreadFactory("NetworkMapUpdaterNodeInfoDownloadThread"))
val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(
threadsToUseForNetworkMapDownload,
NamedThreadFactory("NetworkMapUpdaterNodeInfoDownload")
)
//DB insert is single threaded - use a single threaded executor for it.
val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsertThread"))
val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsert"))
val hashesToFetch = (allHashesFromNetworkMap - allNodeHashes)
val networkMapDownloadStartTime = System.currentTimeMillis()
if (hashesToFetch.isNotEmpty()) {

View File

@ -22,8 +22,7 @@ class InternalRPCMessagingClient(val sslConfig: MutualSslConfiguration, val serv
private var rpcServer: RPCServer? = null
fun init(rpcOps: List<RPCOps>, securityManager: RPCSecurityManager, cacheFactory: NamedCacheFactory) = synchronized(this) {
val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig)
val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig, threadPoolName = "RPCClient")
locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply {
// Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this
// would be the default and the two lines below can be deleted.

View File

@ -30,10 +30,10 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int,
setDirectories(baseDirectory)
val acceptorConfigurationsSet = mutableSetOf(
rpcAcceptorTcpTransport(address, sslOptions, useSsl)
rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl, threadPoolName = "RPCServer")
)
adminAddress?.let {
acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration)
acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration, threadPoolName = "RPCServerAdmin")
}
acceptorConfigurations = acceptorConfigurationsSet

View File

@ -1,5 +1,6 @@
package net.corda.node.services.statemachine
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.flows.FlowSession
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
@ -22,10 +23,6 @@ internal class FlowMonitor(
) : LifecycleSupport {
private companion object {
private fun defaultScheduler(): ScheduledExecutorService {
return Executors.newSingleThreadScheduledExecutor()
}
private val logger = loggerFor<FlowMonitor>()
}
@ -36,7 +33,7 @@ internal class FlowMonitor(
override fun start() {
synchronized(this) {
if (scheduler == null) {
scheduler = defaultScheduler()
scheduler = Executors.newSingleThreadScheduledExecutor(DefaultThreadFactory("FlowMonitor"))
shutdownScheduler = true
}
scheduler!!.scheduleAtFixedRate({ logFlowsWaitingForParty() }, 0, monitoringPeriod.toMillis(), TimeUnit.MILLISECONDS)

View File

@ -25,6 +25,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.services.vault.toStateRef
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
@ -157,13 +158,7 @@ class PersistentUniquenessProvider(val clock: Clock, val database: CordaPersiste
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable
val txId = it.id.txId
val index = it.id.index
Pair(
StateRef(txhash = SecureHash.create(txId), index = index),
SecureHash.create(it.consumingTxHash)
)
Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash))
},
toPersistentEntity = { (txHash, index): StateRef, id: SecureHash ->
CommittedState(

View File

@ -3,28 +3,65 @@ package net.corda.node.services.vault
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.*
import net.corda.core.contracts.Amount
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.FungibleState
import net.corda.core.contracts.Issued
import net.corda.core.contracts.OwnableState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.containsAny
import net.corda.core.flows.HospitalizeFlowException
import net.corda.core.internal.*
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.TransactionDeserialisationException
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.StatesToRecord
import net.corda.core.node.services.*
import net.corda.core.node.services.Vault.ConstraintInfo.Companion.constraintInfo
import net.corda.core.node.services.vault.*
import net.corda.core.node.services.KeyManagementService
import net.corda.core.node.services.StatesNotAvailableException
import net.corda.core.node.services.Vault
import net.corda.core.node.services.VaultQueryException
import net.corda.core.node.services.VaultService
import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
import net.corda.core.node.services.vault.PageSpecification
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.SortAttribute
import net.corda.core.node.services.vault.builder
import net.corda.core.observable.internal.OnResilientSubscribe
import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.*
import net.corda.core.utilities.*
import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.FullTransaction
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.trace
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.services.api.SchemaService
import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.schema.PersistentStateService
import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.nodeapi.internal.persistence.*
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import org.hibernate.Session
import org.hibernate.query.Query
import rx.Observable
import rx.exceptions.OnErrorNotImplementedException
import rx.subjects.PublishSubject
@ -32,9 +69,11 @@ import java.security.PublicKey
import java.sql.SQLException
import java.time.Clock
import java.time.Instant
import java.util.*
import java.util.Arrays
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArraySet
import java.util.stream.Stream
import javax.persistence.PersistenceException
import javax.persistence.Tuple
import javax.persistence.criteria.CriteriaBuilder
@ -54,9 +93,9 @@ import javax.persistence.criteria.Root
class NodeVaultService(
private val clock: Clock,
private val keyManagementService: KeyManagementService,
private val servicesForResolution: ServicesForResolution,
private val servicesForResolution: NodeServicesForResolution,
private val database: CordaPersistence,
private val schemaService: SchemaService,
schemaService: SchemaService,
private val appClassloader: ClassLoader
) : SingletonSerializeAsToken(), VaultServiceInternal {
companion object {
@ -196,7 +235,7 @@ class NodeVaultService(
if (lockId != null) {
lockId = null
lockUpdateTime = clock.instant()
log.trace("Releasing soft lock on consumed state: $stateRef")
log.trace { "Releasing soft lock on consumed state: $stateRef" }
}
session.save(state)
}
@ -227,7 +266,7 @@ class NodeVaultService(
}
// we are not inside a flow, we are most likely inside a CordaService;
// we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates.
return _rawUpdatesPublisher.resilientOnError()
_rawUpdatesPublisher.resilientOnError()
}
override val updates: Observable<Vault.Update<ContractState>>
@ -639,7 +678,23 @@ class NodeVaultService(
@Throws(VaultQueryException::class)
override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> {
try {
return _queryBy(criteria, paging, sorting, contractStateType, false)
// We decrement by one if the client requests MAX_VALUE, assuming they can not notice this because they don't have enough memory
// to request MAX_VALUE states at once.
val validPaging = if (paging.pageSize == Integer.MAX_VALUE) {
paging.copy(pageSize = Integer.MAX_VALUE - 1)
} else {
checkVaultQuery(paging.pageSize >= 1) { "Page specification: invalid page size ${paging.pageSize} [minimum is 1]" }
paging
}
if (!validPaging.isDefault) {
checkVaultQuery(validPaging.pageNumber >= DEFAULT_PAGE_NUM) {
"Page specification: invalid page number ${validPaging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]"
}
}
log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $validPaging, sorting: $sorting" }
return database.transaction {
queryBy(criteria, validPaging, sorting, contractStateType)
}
} catch (e: VaultQueryException) {
throw e
} catch (e: Exception) {
@ -647,100 +702,90 @@ class NodeVaultService(
}
}
@Throws(VaultQueryException::class)
private fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging_: PageSpecification, sorting: Sort, contractStateType: Class<out T>, skipPagingChecks: Boolean): Vault.Page<T> {
// We decrement by one if the client requests MAX_PAGE_SIZE, assuming they can not notice this because they don't have enough memory
// to request `MAX_PAGE_SIZE` states at once.
val paging = if (paging_.pageSize == Integer.MAX_VALUE) {
paging_.copy(pageSize = Integer.MAX_VALUE - 1)
} else {
paging_
private fun <T : ContractState> queryBy(criteria: QueryCriteria,
paging: PageSpecification,
sorting: Sort,
contractStateType: Class<out T>): Vault.Page<T> {
// calculate total results where a page specification has been defined
val totalStatesAvailable = if (paging.isDefault) -1 else queryTotalStateCount(criteria, contractStateType)
val (query, stateTypes) = createQuery(criteria, contractStateType, sorting)
query.setResultWindow(paging)
val statesMetadata: MutableList<Vault.StateMetadata> = mutableListOf()
val otherResults: MutableList<Any> = mutableListOf()
query.resultStream(paging).use { results ->
results.forEach { result ->
val result0 = result[0]
if (result0 is VaultSchemaV1.VaultStates) {
statesMetadata.add(result0.toStateMetadata())
} else {
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }
otherResults.addAll(result.toArray().asList())
}
}
}
log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" }
return database.transaction {
// calculate total results where a page specification has been defined
var totalStates = -1L
if (!skipPagingChecks && !paging.isDefault) {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val results = _queryBy(criteria.and(countCriteria), PageSpecification(), Sort(emptyList()), contractStateType, true) // only skip pagination checks for total results count query
totalStates = results.otherResults.last() as Long
}
val session = getSession()
val states: List<StateAndRef<T>> = servicesForResolution.loadStates(
statesMetadata.mapTo(LinkedHashSet()) { it.ref },
ArrayList()
)
val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java)
val queryRootVaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java)
// TODO: revisit (use single instance of parser for all queries)
val criteriaParser = HibernateQueryCriteriaParser(contractStateType, contractStateTypeMappings, criteriaBuilder, criteriaQuery, queryRootVaultStates)
// parse criteria and build where predicates
criteriaParser.parse(criteria, sorting)
// prepare query for execution
val query = session.createQuery(criteriaQuery)
// pagination checks
if (!skipPagingChecks && !paging.isDefault) {
// pagination
if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]")
if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [minimum is 1]")
if (paging.pageSize > MAX_PAGE_SIZE) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [maximum is $MAX_PAGE_SIZE]")
}
// For both SQLServer and PostgresSQL, firstResult must be >= 0. So we set a floor at 0.
// TODO: This is a catch-all solution. But why is the default pageNumber set to be -1 in the first place?
// Even if we set the default pageNumber to be 1 instead, that may not cover the non-default cases.
// So the floor may be necessary anyway.
query.firstResult = maxOf(0, (paging.pageNumber - 1) * paging.pageSize)
val pageSize = paging.pageSize + 1
query.maxResults = if (pageSize > 0) pageSize else Integer.MAX_VALUE // detection too many results, protected against overflow
// execution
val results = query.resultList
return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults)
}
private fun <R> Query<R>.resultStream(paging: PageSpecification): Stream<R> {
return if (paging.isDefault) {
val allResults = resultList
// final pagination check (fail-fast on too many results when no pagination specified)
if (!skipPagingChecks && paging.isDefault && results.size > DEFAULT_PAGE_SIZE) {
throw VaultQueryException("There are ${results.size} results, which exceeds the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. In order to retrieve these results, provide a `PageSpecification(pageNumber, pageSize)` to the method invoked.")
checkVaultQuery(allResults.size != paging.pageSize + 1) {
"There are more results than the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. " +
"In order to retrieve these results, provide a PageSpecification to the method invoked."
}
val statesAndRefs: MutableList<StateAndRef<T>> = mutableListOf()
val statesMeta: MutableList<Vault.StateMetadata> = mutableListOf()
val otherResults: MutableList<Any> = mutableListOf()
val stateRefs = mutableSetOf<StateRef>()
results.asSequence()
.forEachIndexed { index, result ->
if (result[0] is VaultSchemaV1.VaultStates) {
if (!paging.isDefault && index == paging.pageSize) // skip last result if paged
return@forEachIndexed
val vaultState = result[0] as VaultSchemaV1.VaultStates
val stateRef = StateRef(SecureHash.create(vaultState.stateRef!!.txId), vaultState.stateRef!!.index)
stateRefs.add(stateRef)
statesMeta.add(Vault.StateMetadata(stateRef,
vaultState.contractStateClassName,
vaultState.recordedTime,
vaultState.consumedTime,
vaultState.stateStatus,
vaultState.notary,
vaultState.lockId,
vaultState.lockUpdateTime,
vaultState.relevancyStatus,
constraintInfo(vaultState.constraintType, vaultState.constraintData)
))
} else {
// TODO: improve typing of returned other results
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }
otherResults.addAll(result.toArray().asList())
}
}
if (stateRefs.isNotEmpty())
statesAndRefs.addAll(uncheckedCast(servicesForResolution.loadStates(stateRefs)))
Vault.Page(states = statesAndRefs, statesMetadata = statesMeta, stateTypes = criteriaParser.stateTypes, totalStatesAvailable = totalStates, otherResults = otherResults)
allResults.stream()
} else {
stream()
}
}
private fun Query<*>.setResultWindow(paging: PageSpecification) {
if (paging.isDefault) {
// For both SQLServer and PostgresSQL, firstResult must be >= 0.
firstResult = 0
// Peek ahead and see if there are more results in case pagination should be done
maxResults = paging.pageSize + 1
} else {
firstResult = (paging.pageNumber - 1) * paging.pageSize
maxResults = paging.pageSize
}
}
private fun <T : ContractState> queryTotalStateCount(baseCriteria: QueryCriteria, contractStateType: Class<out T>): Long {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val criteria = baseCriteria.and(countCriteria)
val (query) = createQuery(criteria, contractStateType, null)
val results = query.resultList
return results.last().toArray().last() as Long
}
private fun <T : ContractState> createQuery(criteria: QueryCriteria,
contractStateType: Class<out T>,
sorting: Sort?): Pair<Query<Tuple>, Vault.StateStatus> {
val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java)
val criteriaParser = HibernateQueryCriteriaParser(
contractStateType,
contractStateTypeMappings,
criteriaBuilder,
criteriaQuery,
criteriaQuery.from(VaultSchemaV1.VaultStates::class.java)
)
criteriaParser.parse(criteria, sorting)
val query = getSession().createQuery(criteriaQuery)
return Pair(query, criteriaParser.stateTypes)
}
/**
* Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates.
*
@ -775,6 +820,12 @@ class NodeVaultService(
}
}
private inline fun checkVaultQuery(value: Boolean, lazyMessage: () -> Any) {
if (!value) {
throw VaultQueryException(lazyMessage().toString())
}
}
private fun <T : ContractState> filterContractStates(update: Vault.Update<T>, contractStateType: Class<out T>) =
update.copy(consumed = filterByContractState(contractStateType, update.consumed),
produced = filterByContractState(contractStateType, update.produced))
@ -802,6 +853,7 @@ class NodeVaultService(
}
private fun getSession() = database.currentOrNew().session
/**
* Derive list from existing vault states and then incrementally update using vault observables
*/

View File

@ -2,7 +2,9 @@ package net.corda.node.services.vault
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.MAX_ISSUER_REF_SIZE
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
@ -192,3 +194,19 @@ object VaultSchemaV1 : MappedSchema(
) : IndirectStatePersistable<PersistentStateRefAndKey>
}
fun PersistentStateRef.toStateRef(): StateRef = StateRef(SecureHash.create(txId), index)
fun VaultSchemaV1.VaultStates.toStateMetadata(): Vault.StateMetadata {
return Vault.StateMetadata(
stateRef!!.toStateRef(),
contractStateClassName,
recordedTime,
consumedTime,
stateStatus,
notary,
lockId,
lockUpdateTime,
relevancyStatus,
Vault.ConstraintInfo.constraintInfo(constraintType, constraintData)
)
}

View File

@ -21,6 +21,7 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.toStateRef
import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.security.PublicKey
@ -41,6 +42,8 @@ class BFTSmartNotaryService(
) : NotaryService() {
companion object {
private val log = contextLogger()
@Suppress("unused") // Used by NotaryLoader via reflection
@JvmStatic
val serializationFilter
get() = { clazz: Class<*> ->
@ -147,12 +150,7 @@ class BFTSmartNotaryService(
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable
val txId = it.id.txId
val index = it.id.index
Pair(
StateRef(txhash = SecureHash.create(txId), index = index),
SecureHash.create(it.consumingTxHash)
)
Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash))
},
toPersistentEntity = { (txHash, index): StateRef, id: SecureHash ->
CommittedState(

View File

@ -24,6 +24,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.notary.common.InternalResult
@ -142,10 +143,6 @@ class JPAUniquenessProvider(
fun encodeStateRef(s: StateRef): PersistentStateRef {
return PersistentStateRef(s.txhash.toString(), s.index)
}
fun decodeStateRef(s: PersistentStateRef): StateRef {
return StateRef(txhash = SecureHash.create(s.txId), index = s.index)
}
}
/**
@ -215,15 +212,15 @@ class JPAUniquenessProvider(
committedStates.addAll(existing)
}
return committedStates.map {
val stateRef = StateRef(txhash = SecureHash.create(it.id.txId), index = it.id.index)
return committedStates.associate {
val stateRef = it.id.toStateRef()
val consumingTxId = SecureHash.create(it.consumingTxHash)
if (stateRef in references) {
stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE)
} else {
stateRef to StateConsumptionDetails(consumingTxId.reHash())
}
}.toMap()
}
}
private fun<T> withRetry(block: () -> T): T {

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeOpenSSLContextFactory

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeSSLContextFactory

View File

@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
bobNode.internals.disableDBCloseOnStop()
bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer)
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10)
}
val alicesFakePaper = aliceNode.database.transaction {
@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
val issuer = bank.ref(1, 2, 3)
bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer)
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10)
}
val alicesFakePaper = aliceNode.database.transaction {
fillUpForSeller(false, issuer, alice,

View File

@ -28,12 +28,14 @@ import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.test.SampleCashSchemaV1
import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.services.schema.ContractStateAndRef
import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.schema.PersistentStateService
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.services.vault.toStateRef
import net.corda.node.testing.DummyFungibleContract
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
@ -48,7 +50,6 @@ import net.corda.testing.internal.vault.VaultFiller
import net.corda.testing.node.MockServices
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.`in`
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.hibernate.SessionFactory
@ -122,7 +123,14 @@ class HibernateConfigurationTest {
services = object : MockServices(cordappPackages, BOB_NAME, mock<IdentityService>().also {
doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME })
}, generateKeyPair(), dummyNotary.keyPair) {
override val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, servicesForResolution, database, schemaService, cordappClassloader).apply { start() }
override val vaultService = NodeVaultService(
Clock.systemUTC(),
keyManagementService,
servicesForResolution as NodeServicesForResolution,
database,
schemaService,
cordappClassloader
).apply { start() }
override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable<SignedTransaction>) {
for (stx in txs) {
(validatedTransactions as WritableTransactionStorage).addTransaction(stx)
@ -183,7 +191,7 @@ class HibernateConfigurationTest {
// execute query
val queryResults = entityManager.createQuery(criteriaQuery).resultList
val coins = queryResults.map {
services.loadState(toStateRef(it.stateRef!!)).data
services.loadState(it.stateRef!!.toStateRef()).data
}.sumCash()
assertThat(coins.toDecimal() >= BigDecimal("50.00"))
}
@ -739,7 +747,7 @@ class HibernateConfigurationTest {
val queryResults = entityManager.createQuery(criteriaQuery).resultList
queryResults.forEach {
val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State
val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State
println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}")
}
@ -823,7 +831,7 @@ class HibernateConfigurationTest {
// execute query
val queryResults = entityManager.createQuery(criteriaQuery).resultList
queryResults.forEach {
val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State
val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State
println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}")
}
@ -961,10 +969,6 @@ class HibernateConfigurationTest {
}
}
private fun toStateRef(pStateRef: PersistentStateRef): StateRef {
return StateRef(SecureHash.create(pStateRef.txId), pStateRef.index)
}
@Test(timeout=300_000)
fun `schema change`() {
fun createNewDB(schemas: Set<MappedSchema>, initialiseSchema: Boolean = true): CordaPersistence {

View File

@ -244,7 +244,6 @@ class FlowSoftLocksTests {
100.DOLLARS,
bankNode.services,
thisManyStates,
thisManyStates,
cashIssuer
)
}

View File

@ -20,14 +20,13 @@ import net.corda.finance.*
import net.corda.finance.contracts.CommercialPaper
import net.corda.finance.contracts.Commodity
import net.corda.finance.contracts.DealState
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.finance.contracts.asset.Cash
import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.schemas.CashSchemaV1.PersistentCashState
import net.corda.finance.schemas.CommercialPaperSchemaV1
import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3
import net.corda.finance.workflows.CommercialPaperUtils
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction
@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
private fun setUpDb(_database: CordaPersistence, delay: Long = 0) {
_database.transaction {
private fun setUpDb(database: CordaPersistence, delay: Long = 0) {
database.transaction {
// create new states
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER)
val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ")
@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates)
}
(1..3).forEach {
repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates)
}
(1..3).forEach {
repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
val sorted = results.states.sortedBy { it.ref.toString() }
assertThat(results.states).isEqualTo(sorted)
assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) }
assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) }
}
}
@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789"))
// count fungible assets
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count)
val countCriteria = VaultCustomQueryCriteria(count)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L)
@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
// count fungible assets
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L)
@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// UNCONSUMED states (default)
// count fungible assets
val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val fungibleStateCountUnconsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size)
@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// CONSUMED states
// count fungible assets
val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val fungibleStateCountConsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size)
@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val start = TODAY
val end = TODAY.plus(30, ChronoUnit.DAYS)
val recordedBetweenExpression = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED,
TimeInstantType.RECORDED,
ColumnPredicate.Between(start, end))
val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression)
val results = vaultService.queryBy<ContractState>(criteria)
@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Future
val startFuture = TODAY.plus(1, ChronoUnit.DAYS)
val recordedBetweenExpressionFuture = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end))
TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end))
val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture)
assertThat(vaultService.queryBy<ContractState>(criteriaFuture).states).isEmpty()
}
@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
consumeCash(100.DOLLARS)
val asOfDateTime = TODAY
val consumedAfterExpression = TimeCondition(
QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime))
TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime))
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED,
timeCondition = consumedAfterExpression)
val results = vaultService.queryBy<ContractState>(criteria)
@ -1674,7 +1674,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// pagination: last page
@Test(timeout=300_000)
fun `all states with paging specification - last`() {
fun `all states with paging specification - last`() {
database.transaction {
vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER)
// Last page implies we need to perform a row count for the Query first,
@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
// pagination: invalid page size
@Suppress("INTEGER_OVERFLOW")
@Test(timeout=300_000)
fun `invalid page size`() {
expectedEx.expect(VaultQueryException::class.java)
@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
database.transaction {
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER)
@Suppress("EXPECTED_CONDITION")
val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648
val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.queryBy<ContractState>(criteria, paging = pagingSpec)
}
@ -1723,7 +1723,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
@Test(timeout=300_000)
fun `pagination not specified but more than default results available`() {
expectedEx.expect(VaultQueryException::class.java)
expectedEx.expectMessage("provide a `PageSpecification(pageNumber, pageSize)`")
expectedEx.expectMessage("provide a PageSpecification")
database.transaction {
vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER)
@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
println("$index : $any")
}
assertThat(results.otherResults.size).isEqualTo(402)
val instants = results.otherResults.filter { it is Instant }.map { it as Instant }
val instants = results.otherResults.filterIsInstance<Instant>()
assertThat(instants).isSorted
val longs = results.otherResults.filter { it is Long }.map { it as Long }
val longs = results.otherResults.filterIsInstance<Long>()
assertThat(longs.size).isEqualTo(201)
assertThat(instants.size).isEqualTo(201)
assertThat(longs.sum()).isEqualTo(20100L)
@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() {
database.transaction {
val uid = UniqueIdentifier("999")
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid)
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234")
vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid)
vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234")
val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id))
val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234"))
@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
}
}
@Test(timeout = 300_000)
fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() {
val linearNumbers = Array(2) { LongArray(2) }
// Make sure states from the same transaction are not given consecutive linear numbers.
linearNumbers[0][0] = 1L
linearNumbers[0][1] = 3L
linearNumbers[1][0] = 2L
linearNumbers[1][1] = 4L
val results = database.transaction {
vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex ->
DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex])
}
val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber"))
vaultService.queryBy<DummyLinearContract.State>(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn)))
}
assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L))
}
@Test(timeout=300_000)
fun `return consumed linear states for a given linear id`() {
database.transaction {
@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
services.recordTransactions(commercialPaper2)
val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) }
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex)
val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val result = vaultService.queryBy<CommercialPaper.State>(criteria1)
@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days)
val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L)
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex)
val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex)
val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex)
val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val criteria2 = VaultCustomQueryCriteria(maturityIndex)
val criteria3 = VaultCustomQueryCriteria(faceValueIndex)
vaultService.queryBy<CommercialPaper.State>(criteria1.and(criteria3).and(criteria2))
}
@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL)
val results = builder {
val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode)
val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L)
val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode)
val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L)
val customCriteria1 = VaultCustomQueryCriteria(currencyIndex)
val customCriteria2 = VaultCustomQueryCriteria(quantityIndex)
@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Enrich and override QueryCriteria with additional default attributes (such as soft locks)
val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich
softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())),
softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())),
status = Vault.StateStatus.UNCONSUMED) // override
// Sorting
val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF)
@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush()
vaultFiller.consumeLinearStates(states.toList())
vaultFiller.consumeStates(states)
updates
}
@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush()
vaultFiller.consumeLinearStates(states.toList())
vaultFiller.consumeStates(states)
updates
}
@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush()
vaultFiller.consumeLinearStates(states.toList())
vaultFiller.consumeStates(states)
updates
}

View File

@ -10,7 +10,6 @@ import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.AbstractParty
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.uncheckedCast
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.KeyManagementService
import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition
@ -29,6 +28,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.core.singleIdentity
import net.corda.testing.flows.registerCoreFlowFactory
import net.corda.coretesting.internal.rigorousMock
import net.corda.node.internal.NodeServicesForResolution
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.enclosedCordapp
import net.corda.testing.node.internal.startFlow
@ -86,7 +86,7 @@ class VaultSoftLockManagerTest {
private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args ->
object : InternalMockNetwork.MockNode(args) {
override fun makeVaultService(keyManagementService: KeyManagementService,
services: ServicesForResolution,
services: NodeServicesForResolution,
database: CordaPersistence,
cordappLoader: CordappLoader): VaultServiceInternal {
val node = this

View File

@ -1,30 +1,50 @@
@file:Suppress("UNUSED_PARAMETER")
@file:JvmName("TestUtils")
@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList")
package net.corda.testing.core
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.*
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureScheme
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.toX500Name
import net.corda.core.internal.unspecifiedCountry
import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import java.math.BigInteger
import java.net.URI
import java.security.KeyPair
import java.security.PublicKey
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.fail
@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party
return getTestPartyAndCertificate(Party(name, publicKey))
}
fun createCRL(issuer: CertificateAndKeyPair,
revokedCerts: List<X509Certificate>,
issuingDistPoint: URI? = null,
thisUpdate: Instant = Instant.now(),
nextUpdate: Instant = thisUpdate + 5.minutes,
indirect: Boolean = false,
revocationDate: Instant = thisUpdate,
crlReason: Int = CRLReason.keyCompromise,
signatureAlgorithm: String = "SHA256withECDSA"): X509CRL {
val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate))
// This is required and needs to match the certificate settings with respect to being indirect
builder.addExtension(
Extension.issuingDistributionPoint,
true,
IssuingDistributionPoint(
issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) },
indirect,
false
)
)
builder.setNextUpdate(Date.from(nextUpdate))
for (revokedCert in revokedCerts) {
val extensionsGenerator = ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason))
// Certificate issuer is required for indirect CRL
extensionsGenerator.addExtension(
Extension.certificateIssuer,
true,
GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name()))
)
builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate())
}
val bcProvider = Crypto.findProvider("BC")
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private)
return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer))
}
private val count = AtomicInteger(0)
/**
@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party
* The above will test our expectation that the getWaitingFlows action was executed successfully considering
* that it may take a few hundreds of milliseconds for the flow state machine states to settle.
*/
@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod")
fun <T> executeTest(
timeout: Duration,
cleanup: (() -> Unit)? = null,

View File

@ -28,6 +28,7 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.VersionInfo
import net.corda.node.internal.ServicesForResolutionImpl
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.internal.cordapp.JarScanningCordappLoader
import net.corda.node.services.api.*
import net.corda.node.services.diagnostics.NodeDiagnosticsService
@ -463,7 +464,14 @@ open class MockServices private constructor(
get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions)
internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal {
return NodeVaultService(clock, keyManagementService, servicesForResolution, database, schemaService, cordappLoader.appClassLoader).apply { start() }
return NodeVaultService(
clock,
keyManagementService,
servicesForResolution as NodeServicesForResolution,
database,
schemaService,
cordappLoader.appClassLoader
).apply { start() }
}
// This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API

View File

@ -4,30 +4,26 @@ package net.corda.testing.node.internal.network
import net.corda.core.crypto.Crypto
import net.corda.core.internal.CertRole
import net.corda.core.internal.toX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.days
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.ContentSignerBuilder
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import net.corda.nodeapi.internal.crypto.certificateType
import net.corda.nodeapi.internal.crypto.toJca
import org.bouncycastle.asn1.x500.X500Name
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.asn1.x509.ReasonFlags
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.ServerConnector
import org.eclipse.jetty.server.handler.HandlerCollection
@ -36,11 +32,12 @@ import org.eclipse.jetty.servlet.ServletHolder
import org.glassfish.jersey.server.ResourceConfig
import org.glassfish.jersey.servlet.ServletContainer
import java.io.Closeable
import java.math.BigInteger
import java.net.InetSocketAddress
import java.net.URI
import java.security.KeyPair
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.*
import javax.security.auth.x500.X500Principal
import javax.ws.rs.GET
@ -51,7 +48,7 @@ import kotlin.collections.ArrayList
class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
companion object {
private const val SIGNATURE_ALGORITHM = "SHA256withECDSA"
private val logger = contextLogger()
const val NODE_CRL = "node.crl"
const val FORBIDDEN_CRL = "forbidden.crl"
@ -72,8 +69,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
null
)
if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint)))
val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(X500Name.getInstance(it.encoded))) }
val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) }
val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
}
@ -87,14 +84,17 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
}
}
val revokedNodeCerts: MutableList<BigInteger> = ArrayList()
val revokedIntermediateCerts: MutableList<BigInteger> = ArrayList()
val revokedNodeCerts: MutableList<X509Certificate> = ArrayList()
val revokedIntermediateCerts: MutableList<X509Certificate> = ArrayList()
val rootCa: CertificateAndKeyPair = DEV_ROOT_CA
private lateinit var _intermediateCa: CertificateAndKeyPair
val intermediateCa: CertificateAndKeyPair get() = _intermediateCa
@Volatile
var delay: Duration? = null
val hostAndPort: NetworkHostAndPort
get() = server.connectors.mapNotNull { it as? ServerConnector }
.map { NetworkHostAndPort(it.host, it.localPort) }
@ -106,7 +106,7 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"),
DEV_INTERMEDIATE_CA.keyPair
)
println("Network management web services started on $hostAndPort")
logger.info("Network management web services started on $hostAndPort")
}
fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate,
@ -115,29 +115,20 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer)
}
fun createRevocationList(signatureAlgorithm: String,
ca: CertificateAndKeyPair,
endpoint: String,
indirect: Boolean,
serialNumbers: List<BigInteger>): X509CRL {
println("Generating CRL for $endpoint")
val builder = JcaX509v2CRLBuilder(ca.certificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis()))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(ca.certificate))
val issuingDistPointName = GeneralName(GeneralName.uniformResourceIdentifier, "http://$hostAndPort/crl/$endpoint")
// This is required and needs to match the certificate settings with respect to being indirect
val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
serialNumbers.forEach {
builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
}
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(ca.keyPair.private)
return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer))
private fun createServerCRL(issuer: CertificateAndKeyPair,
endpoint: String,
indirect: Boolean,
revokedCerts: List<X509Certificate>): X509CRL {
logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}")
return createCRL(
issuer,
revokedCerts,
issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"),
indirect = indirect
)
}
override fun close() {
println("Shutting down network management web services...")
server.stop()
server.join()
}
@ -159,8 +150,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(NODE_CRL)
@Produces("application/pkcs7-crl")
fun getNodeCRL(): Response {
return Response.ok(crlServer.createRevocationList(
SIGNATURE_ALGORITHM,
crlServer.delay?.toMillis()?.let(Thread::sleep)
return Response.ok(crlServer.createServerCRL(
crlServer.intermediateCa,
NODE_CRL,
false,
@ -179,8 +170,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(INTERMEDIATE_CRL)
@Produces("application/pkcs7-crl")
fun getIntermediateCRL(): Response {
return Response.ok(crlServer.createRevocationList(
SIGNATURE_ALGORITHM,
crlServer.delay?.toMillis()?.let(Thread::sleep)
return Response.ok(crlServer.createServerCRL(
crlServer.rootCa,
INTERMEDIATE_CRL,
false,
@ -192,11 +183,11 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(EMPTY_CRL)
@Produces("application/pkcs7-crl")
fun getEmptyCRL(): Response {
return Response.ok(crlServer.createRevocationList(
SIGNATURE_ALGORITHM,
return Response.ok(crlServer.createServerCRL(
crlServer.rootCa,
EMPTY_CRL,
true, emptyList()
true,
emptyList()
).encoded).build()
}
}

View File

@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.serialization.internal.amqp.AMQP_ENABLED
import net.corda.testing.core.ALICE_NAME
@ -52,6 +53,8 @@ import java.io.IOException
import java.net.ServerSocket
import java.nio.file.Path
import java.security.KeyPair
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.*
import java.util.jar.JarOutputStream
import java.util.jar.Manifest
@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L
return sslConfig
}
fun fixedCrlSource(crls: Set<X509CRL>): CrlSource {
return object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = crls
}
}
/** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */
@SuppressWarnings("LongParameterList")
fun createWireTransaction(inputs: List<StateRef>,

View File

@ -1,6 +1,20 @@
@file:Suppress("LongParameterList")
package net.corda.testing.internal.vault
import net.corda.core.contracts.*
import net.corda.core.contracts.Amount
import net.corda.core.contracts.AttachmentConstraint
import net.corda.core.contracts.AutomaticPlaceholderConstraint
import net.corda.core.contracts.BelongsToContract
import net.corda.core.contracts.CommandAndState
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.Issued
import net.corda.core.contracts.LinearState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.identity.AbstractParty
@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash
import net.corda.finance.contracts.asset.Obligation
import net.corda.finance.contracts.asset.OnLedgerAsset
import net.corda.finance.workflows.asset.CashUtils
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
import net.corda.testing.core.DummyCommandData
import net.corda.testing.core.TestIdentity
import net.corda.testing.core.dummyCommand
import net.corda.testing.core.singleIdentity
@ -32,6 +44,7 @@ import java.time.Duration
import java.time.Instant
import java.time.Instant.now
import java.util.*
import kotlin.math.floor
/**
* The service hub should provide at least a key management service and a storage service.
@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor(
private val rngFactory: () -> Random = { Random(0L) }) {
companion object {
fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray {
val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt()
val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt()
val baseSize = howMuch.quantity / numSlots
check(baseSize > 0) { baseSize }
@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor(
issuerServices: ServiceHub = services,
participants: List<AbstractParty> = emptyList(),
includeMe: Boolean = true): Vault<DealState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey
val me = AnonymousParty(myKey)
val participantsToUse = if (includeMe) participants.plus(me) else participants
val transactions: List<SignedTransaction> = dealIds.map {
// Issue a deal state
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID)
addCommand(dummyCommand())
}
val stx = issuerServices.signInitialTransaction(dummyIssue)
return@map services.addSignature(stx, defaultNotary.publicKey)
return fillWithTestStates(
txCount = dealIds.size,
participants = participants,
includeMe = includeMe,
services = issuerServices
) { participantsToUse, txIndex, _ ->
DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<DealState>(i) }
}
return Vault(states)
}
@JvmOverloads
fun fillWithSomeTestLinearStates(numberToCreate: Int,
fun fillWithSomeTestLinearStates(txCount: Int,
externalId: String? = null,
participants: List<AbstractParty> = emptyList(),
uniqueIdentifier: UniqueIdentifier? = null,
@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor(
linearTimestamp: Instant = now(),
constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
includeMe: Boolean = true): Vault<LinearState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey
val me = AnonymousParty(myKey)
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID)
val participantsToUse = if (includeMe) participants.plus(me) else participants
val transactions: List<SignedTransaction> = (1..numberToCreate).map {
// Issue a Linear state
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
addOutputState(DummyLinearContract.State(
linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID,
constraint = constraint)
addCommand(dummyCommand())
}
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ ->
DummyLinearContract.State(
linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp
)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
}
@JvmOverloads
fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int,
fun fillWithSomeTestLinearAndDealStates(txCount: Int,
externalId: String? = null,
participants: List<AbstractParty> = emptyList(),
linearString: String = "",
linearNumber: Long = 0L,
linearBoolean: Boolean = false,
linearTimestamp: Instant = now()): Vault<LinearState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey
val me = AnonymousParty(myKey)
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID)
val transactions: List<SignedTransaction> = (1..numberToCreate).map {
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
// Issue a Linear state
addOutputState(DummyLinearContract.State(
linearTimestamp: Instant = now()): Vault<ContractState> {
return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex ->
when (stateIndex) {
0 -> DummyLinearContract.State(
linearId = UniqueIdentifier(externalId),
participants = participants.plus(me),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID)
// Issue a Deal state
addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID)
addCommand(dummyCommand())
linearTimestamp = linearTimestamp
)
else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse)
}
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
}
services.recordTransactions(transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
}
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub,
thisManyStates: Int,
issuedBy: PartyAndReference,
owner: AbstractParty? = null,
rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord)
/**
* Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them
* to the vault. This is intended for unit tests. By default the cash is owned by the legal
@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor(
* @param issuerServices service hub of the issuer node, which will be used to sign the transaction.
* @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!).
*/
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub,
atLeastThisManyStates: Int,
atMostThisManyStates: Int,
issuedBy: PartyAndReference,
owner: AbstractParty? = null,
rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<Cash.State> {
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT,
atMostThisManyStates: Int = atLeastThisManyStates): Vault<Cash.State> {
val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory())
// We will allocate one state to one transaction, for simplicities sake.
val cash = Cash()
@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor(
cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary)
return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
}
services.recordTransactions(statesToRecord, transactions)
// Get all the StateRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<Cash.State>(i) }
}
return Vault(states)
return recordTransactions(transactions, statesToRecord)
}
/**
* Records a dummy state in the Vault (useful for creating random states when testing vault queries)
*/
fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())) : Vault<DummyState> {
val outputState = TransactionState(
data = DummyState(Random().nextInt(), participants = participants),
contract = DummyContract.PROGRAM_ID,
notary = defaultNotary.party
)
val participantKeys : List<PublicKey> = participants.map { it.owningKey }
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, participantKeys)
val stxn = services.signInitialTransaction(builder)
services.recordTransactions(stxn)
return Vault(setOf(stxn.tx.outRef(0)))
fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())): Vault<DummyState> {
return fillWithTestStates(participants = participants) { participantsToUse, _, _ ->
DummyState(Random().nextInt(), participants = participantsToUse)
}
}
/**
* Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey.
*/
fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount<Issued<Commodity>>, owner: AbstractParty, notary: Party)
= OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue())
fun <T : ContractState> fillWithTestStates(txCount: Int = 1,
statesPerTx: Int = 1,
participants: List<AbstractParty> = emptyList(),
constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
includeMe: Boolean = true,
services: ServiceHub = this.services,
genOutputState: (participantsToUse: List<AbstractParty>, txIndex: Int, stateIndex: Int) -> T): Vault<T> {
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(
services.myInfo.platformVersion,
Crypto.findSignatureScheme(issuerKey.public).schemeNumberID
)
val participantsToUse = if (includeMe) {
participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey)
} else {
participants
}
val transactions = Array(txCount) { txIndex ->
val builder = TransactionBuilder(notary = defaultNotary.party)
repeat(statesPerTx) { stateIndex ->
builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint)
}
builder.addCommand(dummyCommand())
services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
return recordTransactions(transactions.asList(), statesToRecord)
}
/**
*
@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor(
val me = AnonymousParty(myKey)
val issuance = TransactionBuilder(null as Party?)
generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary)
OnLedgerAsset.generateIssue(
issuance,
TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary),
Obligation.Commands.Issue()
)
val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
services.recordTransactions(transaction)
return Vault(setOf(transaction.tx.outRef(0)))
return recordTransactions(listOf(transaction))
}
private fun <T : LinearState> consume(states: List<StateAndRef<T>>) {
fun consumeStates(states: Iterable<StateAndRef<*>>) {
// Create a txn consuming different contract types
states.forEach {
val builder = TransactionBuilder(notary = altNotary).apply {
@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor(
}
}
fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consume(dealStates)
fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consume(linearStates)
fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consumeStates(dealStates)
fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeStates(linearStates)
fun evolveLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeAndProduce(linearStates)
fun evolveLinearState(linearState: StateAndRef<LinearState>): StateAndRef<LinearState> = consumeAndProduce(linearState)
/**
* Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios,
* where nodes have a default identity.
@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor(
services.recordTransactions(spendTx)
return update.getOrThrow(Duration.ofSeconds(3))
}
private fun <T : ContractState> recordTransactions(transactions: Iterable<SignedTransaction>,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<T> {
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<T>(i) }
}
return Vault(states)
}
}
@ -344,4 +326,3 @@ data class CommodityState(
override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner))
}