Merge branch 'release/os/4.9' into shams-4.10-merge-e6a80822

# Conflicts:
#	.github/workflows/check-pr-title.yml
#	.snyk
#	node-api/src/main/kotlin/net/corda/nodeapi/internal/protonwrapper/netty/AMQPClient.kt
#	node/src/integration-test/kotlin/net/corda/node/amqp/AMQPClientSslErrorsTest.kt
#	node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
This commit is contained in:
Shams Asari 2023-07-13 10:53:30 +01:00
commit 3a6deeefa7
67 changed files with 1687 additions and 1204 deletions

View File

@ -9,6 +9,6 @@ jobs:
steps: steps:
- uses: morrisoncole/pr-lint-action@v1.6.1 - uses: morrisoncole/pr-lint-action@v1.6.1
with: with:
title-regex: '^((CORDA|AG|EG|ENT|INFRA|NAAS|ES)-\d+|NOTICK)(.*)' title-regex: '^((CORDA|AG|EG|ENT|INFRA|ES)-\d+|NOTICK)(.*)'
on-failed-regex-comment: "PR title failed to match regex -> `%regex%`" on-failed-regex-comment: "PR title failed to match regex -> `%regex%`"
repo-token: "${{ secrets.GITHUB_TOKEN }}" repo-token: "${{ secrets.GITHUB_TOKEN }}"

View File

@ -1,5 +1,6 @@
package net.corda.client.rpc package net.corda.client.rpc
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.ReconnectingCordaRPCOps import net.corda.client.rpc.internal.ReconnectingCordaRPCOps
import net.corda.client.rpc.internal.SerializationEnvironmentHelper import net.corda.client.rpc.internal.SerializationEnvironmentHelper
@ -52,7 +53,7 @@ class CordaRPCConnection private constructor(
sslConfiguration: ClientRpcSslOptions? = null, sslConfiguration: ClientRpcSslOptions? = null,
classLoader: ClassLoader? = null classLoader: ClassLoader? = null
): CordaRPCConnection { ): CordaRPCConnection {
val observersPool: ExecutorService = Executors.newCachedThreadPool() val observersPool: ExecutorService = Executors.newCachedThreadPool(DefaultThreadFactory("RPCObserver"))
return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps( return CordaRPCConnection(null, observersPool, ReconnectingCordaRPCOps(
addresses, addresses,
username, username,

View File

@ -1,5 +1,6 @@
package net.corda.client.rpc.internal package net.corda.client.rpc.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.client.rpc.ConnectionFailureException import net.corda.client.rpc.ConnectionFailureException
import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCClientConfiguration import net.corda.client.rpc.CordaRPCClientConfiguration
@ -99,7 +100,8 @@ class ReconnectingCordaRPCOps private constructor(
ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps ErrorInterceptingHandler(reconnectingRPCConnection)) as CordaRPCOps
} }
} }
private val retryFlowsPool = Executors.newScheduledThreadPool(1) private val retryFlowsPool = Executors.newScheduledThreadPool(1, DefaultThreadFactory("FlowRetry"))
/** /**
* This function runs a flow and retries until it completes successfully. * This function runs a flow and retries until it completes successfully.
* *

View File

@ -0,0 +1,29 @@
package net.corda.coretests.crypto.internal
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.testing.core.createCRL
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.junit.Test
class ProviderMapTest {
// https://github.com/corda/corda/pull/3997
@Test(timeout = 300_000)
fun `verify CRL algorithms`() {
val crl = createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "SHA256withECDSA"
)
// This should pass.
crl.verify(DEV_ROOT_CA.keyPair.public)
// Try changing the algorithm to EC will fail.
assertThatIllegalArgumentException().isThrownBy {
createCRL(
issuer = DEV_ROOT_CA,
revokedCerts = emptyList(),
signatureAlgorithm = "EC"
)
}.withMessage("Unknown signature type requested: EC")
}
}

View File

@ -65,7 +65,7 @@ interface ServicesForResolution {
/** /**
* Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState]. * Given a [Set] of [StateRef]'s loads the referenced transaction and looks up the specified output [ContractState].
* *
* @throws TransactionResolutionException if [stateRef] points to a non-existent transaction. * @throws TransactionResolutionException if any of the [stateRefs] point to a non-existent transaction.
*/ */
// TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load // TODO: future implementation to use a Vault state ref -> contract state BLOB table and perform single query bulk load
// as the existing transaction store will become encrypted at some point // as the existing transaction store will become encrypted at some point

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.core.node.services package net.corda.core.node.services
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
@ -197,8 +199,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
* 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL]. * 4) Status types used in this query: [StateStatus.UNCONSUMED], [StateStatus.CONSUMED], [StateStatus.ALL].
* 5) Other results as a [List] of any type (eg. aggregate function results with/without group by). * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by).
* *
* Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata * Note: currently [otherResults] is used only for aggregate functions (in which case, [states] and [statesMetadata] will be empty).
* results will be empty).
*/ */
@CordaSerializable @CordaSerializable
data class Page<out T : ContractState>(val states: List<StateAndRef<T>>, data class Page<out T : ContractState>(val states: List<StateAndRef<T>>,
@ -213,11 +214,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
val contractStateClassName: String, val contractStateClassName: String,
val recordedTime: Instant, val recordedTime: Instant,
val consumedTime: Instant?, val consumedTime: Instant?,
val status: Vault.StateStatus, val status: StateStatus,
val notary: AbstractParty?, val notary: AbstractParty?,
val lockId: String?, val lockId: String?,
val lockUpdateTime: Instant?, val lockUpdateTime: Instant?,
val relevancyStatus: Vault.RelevancyStatus? = null, val relevancyStatus: RelevancyStatus? = null,
val constraintInfo: ConstraintInfo? = null val constraintInfo: ConstraintInfo? = null
) { ) {
fun copy( fun copy(
@ -225,7 +226,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName, contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime, recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime, consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status, status: StateStatus = this.status,
notary: AbstractParty? = this.notary, notary: AbstractParty? = this.notary,
lockId: String? = this.lockId, lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime lockUpdateTime: Instant? = this.lockUpdateTime
@ -237,11 +238,11 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
contractStateClassName: String = this.contractStateClassName, contractStateClassName: String = this.contractStateClassName,
recordedTime: Instant = this.recordedTime, recordedTime: Instant = this.recordedTime,
consumedTime: Instant? = this.consumedTime, consumedTime: Instant? = this.consumedTime,
status: Vault.StateStatus = this.status, status: StateStatus = this.status,
notary: AbstractParty? = this.notary, notary: AbstractParty? = this.notary,
lockId: String? = this.lockId, lockId: String? = this.lockId,
lockUpdateTime: Instant? = this.lockUpdateTime, lockUpdateTime: Instant? = this.lockUpdateTime,
relevancyStatus: Vault.RelevancyStatus? relevancyStatus: RelevancyStatus?
): StateMetadata { ): StateMetadata {
return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint)) return StateMetadata(ref, contractStateClassName, recordedTime, consumedTime, status, notary, lockId, lockUpdateTime, relevancyStatus, ConstraintInfo(AlwaysAcceptAttachmentConstraint))
} }
@ -249,9 +250,9 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
companion object { companion object {
@Deprecated("No longer used. The vault does not emit empty updates") @Deprecated("No longer used. The vault does not emit empty updates")
val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL, references = emptySet()) val NoUpdate = Update(emptySet(), emptySet(), type = UpdateType.GENERAL, references = emptySet())
@Deprecated("No longer used. The vault does not emit empty updates") @Deprecated("No longer used. The vault does not emit empty updates")
val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE, references = emptySet()) val NoNotaryUpdate = Update(emptySet(), emptySet(), type = UpdateType.NOTARY_CHANGE, references = emptySet())
} }
} }
@ -302,7 +303,7 @@ interface VaultService {
fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> { fun whenConsumed(ref: StateRef): CordaFuture<Vault.Update<ContractState>> {
val query = QueryCriteria.VaultQueryCriteria( val query = QueryCriteria.VaultQueryCriteria(
stateRefs = listOf(ref), stateRefs = listOf(ref),
status = Vault.StateStatus.CONSUMED status = StateStatus.CONSUMED
) )
val result = trackBy<ContractState>(query) val result = trackBy<ContractState>(query)
val snapshot = result.snapshot.states val snapshot = result.snapshot.states
@ -358,8 +359,8 @@ interface VaultService {
/** /**
* Helper function to determine spendable states and soft locking them. * Helper function to determine spendable states and soft locking them.
* Currently performance will be worse than for the hand optimised version in * Currently performance will be worse than for the hand optimised version in
* [Cash.unconsumedCashStatesForSpending]. However, this is fully generic and can operate with custom [FungibleState] * [net.corda.finance.workflows.asset.selection.AbstractCashSelection.unconsumedCashStatesForSpending]. However, this is fully generic
* and [FungibleAsset] states. * and can operate with custom [FungibleState] and [FungibleAsset] states.
* @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states. * @param lockId The [FlowLogic.runId]'s [UUID] of the current flow used to soft lock the states.
* @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the * @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the
* [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the * [contractStateType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the

View File

@ -21,14 +21,29 @@ import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.hours import net.corda.core.utilities.hours
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.nodeapi.internal.installDevNodeCaCertPath import net.corda.nodeapi.internal.installDevNodeCaCertPath
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.SerializationContextImpl import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.SerializationFactoryImpl import net.corda.serialization.internal.SerializationFactoryImpl
@ -37,25 +52,16 @@ import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.coretesting.internal.NettyTestClient
import net.corda.coretesting.internal.NettyTestHandler
import net.corda.coretesting.internal.NettyTestServer
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.checkValidity
import net.corda.nodeapi.internal.crypto.getSupportedKey
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import net.corda.nodeapi.internal.crypto.save
import net.corda.nodeapi.internal.crypto.toBc
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.crypto.x509Certificates
import net.corda.testing.internal.IS_OPENJ9 import net.corda.testing.internal.IS_OPENJ9
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey import org.bouncycastle.jcajce.provider.asymmetric.edec.BCEdDSAPrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.junit.Assume import org.junit.Assume
@ -74,10 +80,19 @@ import java.security.PrivateKey
import java.security.cert.CertPath import java.security.cert.CertPath
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.* import java.util.*
import javax.net.ssl.* import javax.net.ssl.SSLContext
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.* import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail
class X509UtilitiesTest { class X509UtilitiesTest {
private companion object { private companion object {
@ -295,15 +310,10 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom()) context.init(keyManagers, trustManagers, newSecureRandom())
@ -388,15 +398,8 @@ class X509UtilitiesTest {
sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa) sslConfig.keyStore.get(true).registerDevP2pCertificates(MEGA_CORP.name, rootCa.certificate, intermediateCa)
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get() val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustStore = sslConfig.trustStore.get() val trustManagerFactory = trustManagerFactory(sslConfig.trustStore.get())
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore)
val sslServerContext = SslContextBuilder val sslServerContext = SslContextBuilder
.forServer(keyManagerFactory) .forServer(keyManagerFactory)

View File

@ -42,8 +42,8 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
override fun start(): Started = synchronized(this) { override fun start(): Started = synchronized(this) {
check(started == null) { "start can't be called twice" } check(started == null) { "start can't be called twice" }
val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace) val tcpTransport = p2pConnectorTcpTransport(serverAddress, config, threadPoolName = threadPoolName, trace = trace)
val backupTransports = backupServerAddressPool.map { val backupTransports = backupServerAddressPool.mapIndexed { index, address ->
p2pConnectorTcpTransport(it, config, threadPoolName = threadPoolName, trace = trace) p2pConnectorTcpTransport(address, config, threadPoolName = "$threadPoolName-backup${index+1}", trace = trace)
} }
log.info("Connecting to message broker: $serverAddress") log.info("Connecting to message broker: $serverAddress")

View File

@ -1,16 +1,18 @@
@file:Suppress("LongParameterList")
package net.corda.nodeapi.internal package net.corda.nodeapi.internal
import net.corda.core.messaging.ClientRpcSslOptions import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.BrokerRpcSslOptions import net.corda.nodeapi.BrokerRpcSslOptions
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.config.SslConfiguration import net.corda.nodeapi.internal.config.SslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.activemq.artemis.api.core.TransportConfiguration import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.Path import javax.net.ssl.TrustManagerFactory
@Suppress("LongParameterList") @Suppress("LongParameterList")
class ArtemisTcpTransport { class ArtemisTcpTransport {
@ -23,6 +25,7 @@ class ArtemisTcpTransport {
val TLS_VERSIONS = listOf("TLSv1.2") val TLS_VERSIONS = listOf("TLSv1.2")
const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout" const val SSL_HANDSHAKE_TIMEOUT_NAME = "Corda-SSLHandshakeTimeout"
const val TRUST_MANAGER_FACTORY_NAME = "Corda-TrustManagerFactory"
const val TRACE_NAME = "Corda-Trace" const val TRACE_NAME = "Corda-Trace"
const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName" const val THREAD_POOL_NAME_NAME = "Corda-ThreadPoolName"
@ -30,7 +33,6 @@ class ArtemisTcpTransport {
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop. // Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop.
// It does not use AMQP messages for its own messages e.g. topology and heartbeats. // It does not use AMQP messages for its own messages e.g. topology and heartbeats.
private const val P2P_PROTOCOLS = "CORE,AMQP" private const val P2P_PROTOCOLS = "CORE,AMQP"
private const val RPC_PROTOCOLS = "CORE" private const val RPC_PROTOCOLS = "CORE"
private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf( private fun defaultArtemisOptions(hostAndPort: NetworkHostAndPort, protocols: String) = mapOf(
@ -39,46 +41,35 @@ class ArtemisTcpTransport {
TransportConstants.PORT_PROP_NAME to hostAndPort.port, TransportConstants.PORT_PROP_NAME to hostAndPort.port,
TransportConstants.PROTOCOLS_PROP_NAME to protocols, TransportConstants.PROTOCOLS_PROP_NAME to protocols,
TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null), TransportConstants.USE_GLOBAL_WORKER_POOL_PROP_NAME to (nodeSerializationEnv != null),
TransportConstants.REMOTING_THREADS_PROPNAME to (if (nodeSerializationEnv != null) -1 else 1),
// turn off direct delivery in Artemis - this is latency optimisation that can lead to // turn off direct delivery in Artemis - this is latency optimisation that can lead to
//hick-ups under high load (CORDA-1336) //hick-ups under high load (CORDA-1336)
TransportConstants.DIRECT_DELIVER to false) TransportConstants.DIRECT_DELIVER to false)
private val defaultSSLOptions = mapOf(
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to TLS_VERSIONS.joinToString(","))
private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) { private fun SslConfiguration.addToTransportOptions(options: MutableMap<String, Any>) {
if (keyStore != null || trustStore != null) {
options[TransportConstants.SSL_ENABLED_PROP_NAME] = true
options[TransportConstants.NEED_CLIENT_AUTH_PROP_NAME] = true
}
keyStore?.let { keyStore?.let {
with (it) { with (it) {
path.requireOnDefaultFileSystem() path.requireOnDefaultFileSystem()
options.putAll(get().toKeyStoreTransportOptions(path)) options[TransportConstants.KEYSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.KEYSTORE_PATH_PROP_NAME] = path
options[TransportConstants.KEYSTORE_PASSWORD_PROP_NAME] = get().password
} }
} }
trustStore?.let { trustStore?.let {
with (it) { with (it) {
path.requireOnDefaultFileSystem() path.requireOnDefaultFileSystem()
options.putAll(get().toTrustStoreTransportOptions(path)) options[TransportConstants.TRUSTSTORE_TYPE_PROP_NAME] = "JKS"
options[TransportConstants.TRUSTSTORE_PATH_PROP_NAME] = path
options[TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME] = get().password
} }
} }
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT options[SSL_HANDSHAKE_TIMEOUT_NAME] = handshakeTimeout ?: DEFAULT_SSL_HANDSHAKE_TIMEOUT
} }
private fun CertificateStore.toKeyStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to path,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun CertificateStore.toTrustStoreTransportOptions(path: Path) = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to path,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to password,
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true)
private fun ClientRpcSslOptions.toTransportOptions() = mapOf( private fun ClientRpcSslOptions.toTransportOptions() = mapOf(
TransportConstants.SSL_ENABLED_PROP_NAME to true, TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider, TransportConstants.TRUSTSTORE_TYPE_PROP_NAME to trustStoreProvider,
@ -94,76 +85,110 @@ class ArtemisTcpTransport {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?, config: MutualSslConfiguration?,
trustManagerFactory: TrustManagerFactory? = config?.trustStore?.get()?.let(::trustManagerFactory),
enableSSL: Boolean = true, enableSSL: Boolean = true,
threadPoolName: String = "P2PServer", threadPoolName: String = "P2PServer",
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (enableSSL) { if (enableSSL) {
config?.addToTransportOptions(options) config?.addToTransportOptions(options)
} }
return createAcceptorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) return createAcceptorTransport(
hostAndPort,
P2P_PROTOCOLS,
options,
trustManagerFactory,
enableSSL,
threadPoolName,
trace,
remotingThreads
)
} }
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: MutualSslConfiguration?, config: MutualSslConfiguration?,
enableSSL: Boolean = true, enableSSL: Boolean = true,
threadPoolName: String = "P2PClient", threadPoolName: String = "P2PClient",
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (enableSSL) { if (enableSSL) {
config?.addToTransportOptions(options) config?.addToTransportOptions(options)
} }
return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace) return createConnectorTransport(hostAndPort, P2P_PROTOCOLS, options, enableSSL, threadPoolName, trace, remotingThreads)
} }
fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: BrokerRpcSslOptions?, config: BrokerRpcSslOptions?,
enableSSL: Boolean = true, enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration { threadPoolName: String = "RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) { if (config != null && enableSSL) {
config.keyStorePath.requireOnDefaultFileSystem() config.keyStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions()) options.putAll(config.toTransportOptions())
} }
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCServer", trace) return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, null, enableSSL, threadPoolName, trace, remotingThreads)
} }
fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcConnectorTcpTransport(hostAndPort: NetworkHostAndPort,
config: ClientRpcSslOptions?, config: ClientRpcSslOptions?,
enableSSL: Boolean = true, enableSSL: Boolean = true,
trace: Boolean = false): TransportConfiguration { trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
if (config != null && enableSSL) { if (config != null && enableSSL) {
config.trustStorePath.requireOnDefaultFileSystem() config.trustStorePath.requireOnDefaultFileSystem()
options.putAll(config.toTransportOptions()) options.putAll(config.toTransportOptions())
} }
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace) return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, enableSSL, "RPCClient", trace, remotingThreads)
} }
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration, config: SslConfiguration,
threadPoolName: String = "Internal-RPCClient",
trace: Boolean = false): TransportConfiguration { trace: Boolean = false): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options) config.addToTransportOptions(options)
return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCClient", trace) return createConnectorTransport(hostAndPort, RPC_PROTOCOLS, options, true, threadPoolName, trace, null)
} }
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort,
config: SslConfiguration, config: SslConfiguration,
trace: Boolean = false): TransportConfiguration { threadPoolName: String = "Internal-RPCServer",
trace: Boolean = false,
remotingThreads: Int? = null): TransportConfiguration {
val options = mutableMapOf<String, Any>() val options = mutableMapOf<String, Any>()
config.addToTransportOptions(options) config.addToTransportOptions(options)
return createAcceptorTransport(hostAndPort, RPC_PROTOCOLS, options, true, "Internal-RPCServer", trace) return createAcceptorTransport(
hostAndPort,
RPC_PROTOCOLS,
options,
trustManagerFactory(requireNotNull(config.trustStore).get()),
true,
threadPoolName,
trace,
remotingThreads
)
} }
private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort, private fun createAcceptorTransport(hostAndPort: NetworkHostAndPort,
protocols: String, protocols: String,
options: MutableMap<String, Any>, options: MutableMap<String, Any>,
trustManagerFactory: TrustManagerFactory?,
enableSSL: Boolean, enableSSL: Boolean,
threadPoolName: String, threadPoolName: String,
trace: Boolean): TransportConfiguration { trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
// Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 options[TransportConstants.HANDSHAKE_TIMEOUT] = 0
if (trustManagerFactory != null) {
// NettyAcceptor only creates default TrustManagerFactorys with the provided trust store details. However, we need to use
// more customised instances which use our revocation checkers, so we pass them in, to be picked up by Node(Open)SSLContextFactory.
options[TRUST_MANAGER_FACTORY_NAME] = trustManagerFactory
}
return createTransport( return createTransport(
"net.corda.node.services.messaging.NodeNettyAcceptorFactory", "net.corda.node.services.messaging.NodeNettyAcceptorFactory",
hostAndPort, hostAndPort,
@ -171,7 +196,8 @@ class ArtemisTcpTransport {
options, options,
enableSSL, enableSSL,
threadPoolName, threadPoolName,
trace trace,
remotingThreads
) )
} }
@ -180,15 +206,21 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>, options: MutableMap<String, Any>,
enableSSL: Boolean, enableSSL: Boolean,
threadPoolName: String, threadPoolName: String,
trace: Boolean): TransportConfiguration { trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
if (enableSSL) {
// This is required to stop Client checking URL address vs. Server provided certificate
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
}
return createTransport( return createTransport(
NodeNettyConnectorFactory::class.java.name, CordaNettyConnectorFactory::class.java.name,
hostAndPort, hostAndPort,
protocols, protocols,
options, options,
enableSSL, enableSSL,
threadPoolName, threadPoolName,
trace trace,
remotingThreads
) )
} }
@ -198,13 +230,15 @@ class ArtemisTcpTransport {
options: MutableMap<String, Any>, options: MutableMap<String, Any>,
enableSSL: Boolean, enableSSL: Boolean,
threadPoolName: String, threadPoolName: String,
trace: Boolean): TransportConfiguration { trace: Boolean,
remotingThreads: Int?): TransportConfiguration {
options += defaultArtemisOptions(hostAndPort, protocols) options += defaultArtemisOptions(hostAndPort, protocols)
if (enableSSL) { if (enableSSL) {
options += defaultSSLOptions options[TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME] = CIPHER_SUITES.joinToString(",")
// This is required to stop Client checking URL address vs. Server provided certificate options[TransportConstants.ENABLED_PROTOCOLS_PROP_NAME] = TLS_VERSIONS.joinToString(",")
options[TransportConstants.VERIFY_HOST_PROP_NAME] = false
} }
// By default, use only one remoting thread in tests (https://github.com/corda/corda/pull/2357)
options[TransportConstants.REMOTING_THREADS_PROPNAME] = remotingThreads ?: if (nodeSerializationEnv == null) 1 else -1
options[THREAD_POOL_NAME_NAME] = threadPoolName options[THREAD_POOL_NAME_NAME] = threadPoolName
options[TRACE_NAME] = trace options[TRACE_NAME] = trace
return TransportConfiguration(className, options) return TransportConfiguration(className, options)

View File

@ -1,8 +1,14 @@
@file:JvmName("ArtemisUtils") @file:JvmName("ArtemisUtils")
package net.corda.nodeapi.internal package net.corda.nodeapi.internal
import net.corda.core.internal.declaredField
import org.apache.activemq.artemis.utils.actors.ProcessorBase
import java.nio.file.FileSystems import java.nio.file.FileSystems
import java.nio.file.Path import java.nio.file.Path
import java.util.concurrent.Executor
import java.util.concurrent.ThreadFactory
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.atomic.AtomicInteger
/** /**
* Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use. * Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use.
@ -16,3 +22,29 @@ fun requireMessageSize(messageSize: Int, limit: Int) {
require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" } require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" }
} }
val Executor.rootExecutor: Executor get() {
var executor: Executor = this
while (executor is ProcessorBase<*>) {
executor = executor.declaredField<Executor>("delegate").value
}
return executor
}
fun Executor.setThreadPoolName(threadPoolName: String) {
(rootExecutor as? ThreadPoolExecutor)?.let { it.threadFactory = NamedThreadFactory(threadPoolName, it.threadFactory) }
}
private class NamedThreadFactory(poolName: String, private val delegate: ThreadFactory) : ThreadFactory {
companion object {
private val poolId = AtomicInteger(0)
}
private val prefix = "$poolName-${poolId.incrementAndGet()}-"
private val nextId = AtomicInteger(0)
override fun newThread(r: Runnable): Thread {
val thread = delegate.newThread(r)
thread.name = "$prefix${nextId.incrementAndGet()}"
return thread
}
}

View File

@ -14,15 +14,16 @@ import org.apache.activemq.artemis.utils.ConfigurationHelper
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
class NodeNettyConnectorFactory : ConnectorFactory { class CordaNettyConnectorFactory : ConnectorFactory {
override fun createConnector(configuration: MutableMap<String, Any>?, override fun createConnector(configuration: MutableMap<String, Any>?,
handler: BufferHandler?, handler: BufferHandler?,
listener: ClientConnectionLifeCycleListener?, listener: ClientConnectionLifeCycleListener?,
closeExecutor: Executor?, closeExecutor: Executor,
threadPool: Executor?, threadPool: Executor,
scheduledThreadPool: ScheduledExecutorService?, scheduledThreadPool: ScheduledExecutorService,
protocolManager: ClientProtocolManager?): Connector { protocolManager: ClientProtocolManager?): Connector {
val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration) val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Connector", configuration)
setThreadPoolName(threadPool, closeExecutor, scheduledThreadPool, threadPoolName)
val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
return NettyConnector( return NettyConnector(
configuration, configuration,
@ -31,7 +32,7 @@ class NodeNettyConnectorFactory : ConnectorFactory {
closeExecutor, closeExecutor,
threadPool, threadPool,
scheduledThreadPool, scheduledThreadPool,
MyClientProtocolManager(threadPoolName, trace) MyClientProtocolManager("$threadPoolName-netty", trace)
) )
} }
@ -39,6 +40,17 @@ class NodeNettyConnectorFactory : ConnectorFactory {
override fun getDefaults(): Map<String?, Any?> = NettyConnector.DEFAULT_CONFIG override fun getDefaults(): Map<String?, Any?> = NettyConnector.DEFAULT_CONFIG
private fun setThreadPoolName(threadPool: Executor, closeExecutor: Executor, scheduledThreadPool: ScheduledExecutorService, name: String) {
threadPool.setThreadPoolName("$name-artemis")
// Artemis will actually wrap the same backing Executor to create multiple "OrderedExecutors". In this scenerio both the threadPool
// and the closeExecutor are the same when it comes to the pool names. If however they are different then given them separate names.
if (threadPool.rootExecutor !== closeExecutor.rootExecutor) {
closeExecutor.setThreadPoolName("$name-artemis-closer")
}
// The scheduler is separate
scheduledThreadPool.setThreadPoolName("$name-artemis-scheduler")
}
private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() { private class MyClientProtocolManager(private val threadPoolName: String, private val trace: Boolean) : ActiveMQClientProtocolManager() {
override fun addChannelHandlers(pipeline: ChannelPipeline) { override fun addChannelHandlers(pipeline: ChannelPipeline) {

View File

@ -0,0 +1,32 @@
@file:Suppress("LongParameterList", "MagicNumber")
package net.corda.nodeapi.internal
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.utilities.seconds
import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
/**
* Creates a [ThreadPoolExecutor] which will use a maximum of [maxPoolSize] threads at any given time and will by default idle down to 0
* threads.
*/
fun namedThreadPoolExecutor(maxPoolSize: Int,
corePoolSize: Int = 0,
idleKeepAlive: Duration = 30.seconds,
workQueue: BlockingQueue<Runnable> = LinkedBlockingQueue(),
poolName: String = "pool",
daemonThreads: Boolean = false,
threadPriority: Int = Thread.NORM_PRIORITY): ThreadPoolExecutor {
return ThreadPoolExecutor(
corePoolSize,
maxPoolSize,
idleKeepAlive.toNanos(),
TimeUnit.NANOSECONDS,
workQueue,
DefaultThreadFactory(poolName, daemonThreads, threadPriority)
)
}

View File

@ -22,6 +22,7 @@ import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
@ -31,6 +32,7 @@ import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC import org.slf4j.MDC
import rx.Subscription import rx.Subscription
import java.time.Duration import java.time.Duration
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.ScheduledFuture import java.util.concurrent.ScheduledFuture
@ -53,7 +55,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int, maxMessageSize: Int,
revocationConfig: RevocationConfig, revocationConfig: RevocationConfig,
enableSNI: Boolean, enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider, private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null, private val bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean, trace: Boolean,
sslHandshakeTimeout: Duration?, sslHandshakeTimeout: Duration?,
@ -78,9 +80,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout) private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig,useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout)
private var sharedEventLoopGroup: EventLoopGroup? = null private var sharedEventLoopGroup: EventLoopGroup? = null
private var sslDelegatedTaskExecutor: ExecutorService? = null
private var artemis: ArtemisSessionProvider? = null private var artemis: ArtemisSessionProvider? = null
companion object { companion object {
private val log = contextLogger()
private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads" private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads"
@ -97,18 +101,11 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
* however Artemis and the remote Corda instanced will deduplicate these messages. * however Artemis and the remote Corda instanced will deduplicate these messages.
*/ */
@Suppress("TooManyFunctions") @Suppress("TooManyFunctions")
private class AMQPBridge(val sourceX500Name: String, private inner class AMQPBridge(val sourceX500Name: String,
val queueName: String, val queueName: String,
val targets: List<NetworkHostAndPort>, val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>, val allowedRemoteLegalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration, private val amqpConfig: AMQPConfiguration) {
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService?,
private val bridgeConnectionTTLSeconds: Int) {
companion object {
private val log = contextLogger()
}
private fun withMDC(block: () -> Unit) { private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>() val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>()
@ -116,7 +113,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
MDC.put("queueName", queueName) MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name) MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() }) MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() }) MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString()) MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())
block() block()
} finally { } finally {
@ -134,13 +131,18 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) } private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }
val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup) val amqpClient = AMQPClient(
targets,
allowedRemoteLegalNames,
amqpConfig,
AMQPClient.NettyThreading.Shared(sharedEventLoopGroup!!, sslDelegatedTaskExecutor!!)
)
private var session: ClientSession? = null private var session: ClientSession? = null
private var consumer: ClientConsumer? = null private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null private var connectedSubscription: Subscription? = null
@Volatile @Volatile
private var messagesReceived: Boolean = false private var messagesReceived: Boolean = false
private val eventLoop: EventLoop = sharedEventGroup.next() private val eventLoop: EventLoop = sharedEventLoopGroup!!.next()
private var artemisState: ArtemisState = ArtemisState.STOPPED private var artemisState: ArtemisState = ArtemisState.STOPPED
set(value) { set(value) {
logDebugWithMDC { "State change $field to $value" } logDebugWithMDC { "State change $field to $value" }
@ -152,32 +154,9 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
private var scheduledExecutorService: ScheduledExecutorService private var scheduledExecutorService: ScheduledExecutorService
= Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build()) = Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build())
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) { private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) {
val runnable = { val runnable = {
synchronized(artemis) { synchronized(artemis!!) {
try { try {
val precedingState = artemisState val precedingState = artemisState
artemisState.pending?.cancel(false) artemisState.pending?.cancel(false)
@ -231,7 +210,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
ArtemisState.STOPPING ArtemisState.STOPPING
} }
bridgeMetricsService?.bridgeDisconnected(targets, legalNames) bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
connectedSubscription?.unsubscribe() connectedSubscription?.unsubscribe()
connectedSubscription = null connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first. // Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
@ -243,7 +222,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
if (connected) { if (connected) {
logInfoWithMDC("Bridge Connected") logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames) bridgeMetricsService?.bridgeConnected(targets, allowedRemoteLegalNames)
if (bridgeConnectionTTLSeconds > 0) { if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval // AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS, amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
@ -253,7 +232,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
} }
artemis(ArtemisState.STARTING) { artemis(ArtemisState.STARTING) {
val startedArtemis = artemis.started val startedArtemis = artemis!!.started
if (startedArtemis == null) { if (startedArtemis == null) {
logInfoWithMDC("Bridge Connected but Artemis is disconnected") logInfoWithMDC("Bridge Connected but Artemis is disconnected")
ArtemisState.STOPPED ArtemisState.STOPPED
@ -286,7 +265,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
logInfoWithMDC("Bridge Disconnected") logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false) amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) { if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames) bridgeMetricsService?.bridgeDisconnected(targets, allowedRemoteLegalNames)
} }
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState -> artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected") logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
@ -418,10 +397,10 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
properties[key] = value properties[key] = value
} }
} }
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" } logDebugWithMDC { "Bridged Send to ${allowedRemoteLegalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName) val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox, val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(), allowedRemoteLegalNames.first().toString(),
properties) properties)
sendableMessage.onComplete.then { sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" } logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
@ -457,6 +436,29 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
} }
@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()
object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()
object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()
open val pending: ScheduledFuture<Unit>? = null
override fun toString(): String = javaClass.simpleName
}
override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) { override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
lock.withLock { lock.withLock {
val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() } val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() }
@ -467,8 +469,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
} }
val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize,
revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) } revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) }
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!, val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig)
bridgeMetricsService, bridgeConnectionTTLSeconds)
bridges += newBridge bridges += newBridge
bridgeMetricsService?.bridgeCreated(targets, legalNames) bridgeMetricsService?.bridgeCreated(targets, legalNames)
newBridge newBridge
@ -486,7 +487,7 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
queueNamesToBridgesMap.remove(queueName) queueNamesToBridgesMap.remove(queueName)
} }
bridge.stop() bridge.stop()
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames) bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.allowedRemoteLegalNames)
} }
} }
} }
@ -497,15 +498,16 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
// queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method. // queueNamesToBridgesMap returns a mutable list, .toList converts it to a immutable list so it won't be changed by the [destroyBridge] method.
val bridges = queueNamesToBridgesMap[queueName]?.toList() val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList()) destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map { bridges?.associate {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false) it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.allowedRemoteLegalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap() } ?: emptyMap()
} }
} }
override fun start() { override fun start() {
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("AMQPBridge", Thread.MAX_PRIORITY)) sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS, DefaultThreadFactory("NettyBridge", Thread.MAX_PRIORITY))
val artemis = artemisMessageClientFactory() sslDelegatedTaskExecutor = sslDelegatedTaskExecutor("NettyBridge")
val artemis = artemisMessageClientFactory("ArtemisBridge")
this.artemis = artemis this.artemis = artemis
artemis.start() artemis.start()
} }
@ -522,6 +524,8 @@ open class AMQPBridgeManager(keyStore: CertificateStore,
sharedEventLoopGroup = null sharedEventLoopGroup = null
queueNamesToBridgesMap.clear() queueNamesToBridgesMap.clear()
artemis?.stop() artemis?.stop()
sslDelegatedTaskExecutor?.shutdown()
sslDelegatedTaskExecutor = null
} }
} }
} }

View File

@ -35,7 +35,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
maxMessageSize: Int, maxMessageSize: Int,
revocationConfig: RevocationConfig, revocationConfig: RevocationConfig,
enableSNI: Boolean, enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider, private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
bridgeMetricsService: BridgeMetricsService? = null, bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean = false, trace: Boolean = false,
sslHandshakeTimeout: Duration? = null, sslHandshakeTimeout: Duration? = null,
@ -80,7 +80,7 @@ class BridgeControlListener(private val keyStore: CertificateStore,
bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId" bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId"
bridgeManager.start() bridgeManager.start()
val artemis = artemisMessageClientFactory() val artemis = artemisMessageClientFactory("BridgeControl")
this.artemis = artemis this.artemis = artemis
artemis.start() artemis.start()
val artemisClient = artemis.started!! val artemisClient = artemis.started!!

View File

@ -37,7 +37,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
maxMessageSize: Int, maxMessageSize: Int,
revocationConfig: RevocationConfig, revocationConfig: RevocationConfig,
enableSNI: Boolean, enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider, private val artemisMessageClientFactory: (String) -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null, private val bridgeMetricsService: BridgeMetricsService? = null,
private val isLocalInbox: (String) -> Boolean, private val isLocalInbox: (String) -> Boolean,
trace: Boolean, trace: Boolean,
@ -204,7 +204,7 @@ class LoopbackBridgeManager(keyStore: CertificateStore,
override fun start() { override fun start() {
super.start() super.start()
val artemis = artemisMessageClientFactory() val artemis = artemisMessageClientFactory("LoopbackBridge")
this.artemis = artemis this.artemis = artemis
artemis.start() artemis.start()
} }

View File

@ -5,16 +5,37 @@ package net.corda.nodeapi.internal.crypto
import net.corda.core.CordaOID import net.corda.core.CordaOID
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.internal.* import net.corda.core.internal.CertRole
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.internal.reader
import net.corda.core.internal.signWithCert
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.validate
import net.corda.core.internal.writer
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.millis import net.corda.core.utilities.millis
import net.corda.core.utilities.toHex import net.corda.core.utilities.toHex
import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString import net.corda.nodeapi.internal.protonwrapper.netty.distributionPointsToString
import org.bouncycastle.asn1.* import org.bouncycastle.asn1.ASN1EncodableVector
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.ASN1Sequence
import org.bouncycastle.asn1.DERSequence
import org.bouncycastle.asn1.DERUTF8String
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.style.BCStyle import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
import org.bouncycastle.asn1.x509.BasicConstraints
import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.KeyPurposeId
import org.bouncycastle.asn1.x509.KeyUsage
import org.bouncycastle.asn1.x509.NameConstraints
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils import org.bouncycastle.cert.bc.BcX509ExtensionUtils
@ -32,8 +53,13 @@ import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.security.SignatureException import java.security.SignatureException
import java.security.cert.* import java.security.cert.CertPath
import java.security.cert.Certificate import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.CertificateFactory
import java.security.cert.TrustAnchor
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.time.temporal.ChronoUnit import java.time.temporal.ChronoUnit
@ -359,7 +385,7 @@ object X509Utilities {
private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) { private fun addCrlInfo(builder: X509v3CertificateBuilder, crlDistPoint: String?, crlIssuer: X500Name?) {
if (crlDistPoint != null) { if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer)) GeneralNames(GeneralName(crlIssuer))
} }
@ -379,6 +405,8 @@ object X509Utilities {
bytes[0] = bytes[0].and(0x3F).or(0x40) bytes[0] = bytes[0].and(0x3F).or(0x40)
return BigInteger(bytes) return BigInteger(bytes)
} }
fun toGeneralNames(string: String, tag: Int = GeneralName.directoryName): GeneralNames = GeneralNames(GeneralName(tag, string))
} }
// Assuming cert type to role is 1:1 // Assuming cert type to role is 1:1

View File

@ -27,16 +27,17 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.lang.Long.min import java.lang.Long.min
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException import java.util.concurrent.Executor
import java.util.concurrent.ExecutorService
import java.util.concurrent.ThreadPoolExecutor
import java.time.Duration import java.time.Duration
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
enum class ProxyVersion { enum class ProxyVersion {
@ -63,8 +64,8 @@ data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostA
class AMQPClient(private val targets: List<NetworkHostAndPort>, class AMQPClient(private val targets: List<NetworkHostAndPort>,
val allowedRemoteLegalNames: Set<CordaX500Name>, val allowedRemoteLegalNames: Set<CordaX500Name>,
private val configuration: AMQPConfiguration, private val configuration: AMQPConfiguration,
private val sharedThreadPool: EventLoopGroup? = null, private val nettyThreading: NettyThreading = NettyThreading.NonShared("AMQPClient"),
private val threadPoolName: String = "AMQPClient") : AutoCloseable { private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON) : AutoCloseable {
companion object { companion object {
init { init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
@ -84,14 +85,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
private val lock = ReentrantLock() private val lock = ReentrantLock()
@Volatile @Volatile
private var started: Boolean = false private var started: Boolean = false
private var workerGroup: EventLoopGroup? = null
@Volatile @Volatile
private var clientChannel: Channel? = null private var clientChannel: Channel? = null
// Offset into the list of targets, so that we can implement round-robin reconnect logic. // Offset into the list of targets, so that we can implement round-robin reconnect logic.
private var targetIndex = 0 private var targetIndex = 0
private var currentTarget: NetworkHostAndPort = targets.first() private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL private var retryInterval = MIN_RETRY_INTERVAL
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker()
private val handshakeFailureRetryTargets = mutableSetOf<NetworkHostAndPort>() private val handshakeFailureRetryTargets = mutableSetOf<NetworkHostAndPort>()
private var retryingHandshakeFailures = false private var retryingHandshakeFailures = false
private var retryOffset = 0 private var retryOffset = 0
@ -172,7 +171,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
log.info("Failed to connect to $currentTarget", future.cause()) log.info("Failed to connect to $currentTarget", future.cause())
if (started) { if (started) {
workerGroup?.schedule({ nettyThreading.eventLoopGroup.schedule({
nextTarget() nextTarget()
restart() restart()
}, retryInterval, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
@ -191,7 +190,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
clientChannel = null clientChannel = null
if (started && !amqpActive) { if (started && !amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" } log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" }
workerGroup?.schedule({ nettyThreading.eventLoopGroup.schedule({
nextTarget() nextTarget()
restart() restart()
}, retryInterval, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
@ -199,17 +198,16 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
} }
private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() { private class ClientChannelInitializer(val parent: AMQPClient) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration private val conf = parent.configuration
@Volatile @Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler private lateinit var amqpChannelHandler: AMQPChannelHandler
init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
@Suppress("ComplexMethod") @Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) { override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
@ -249,9 +247,22 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration) val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget val target = parent.currentTarget
val handler = if (parent.configuration.useOpenSsl) { val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc()) createClientOpenSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
ch.alloc(),
parent.nettyThreading.sslDelegatedTaskExecutor
)
} else { } else {
createClientSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory) createClientSslHandler(
target,
parent.allowedRemoteLegalNames,
wrappedKeyManagerFactory,
trustManagerFactory,
parent.nettyThreading.sslDelegatedTaskExecutor
)
} }
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis() handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout.toMillis()
pipeline.addLast("sslHandler", handler) pipeline.addLast("sslHandler", handler)
@ -292,7 +303,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
if (started && amqpActive) { if (started && amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP active)" } log.debug { "Scheduling restart of $currentTarget (AMQP active)" }
workerGroup?.schedule({ nettyThreading.eventLoopGroup.schedule({
nextTarget() nextTarget()
restart() restart()
}, retryInterval, TimeUnit.MILLISECONDS) }, retryInterval, TimeUnit.MILLISECONDS)
@ -309,7 +320,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
return return
} }
log.info("Connect to: $currentTarget") log.info("Connect to: $currentTarget")
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY)) (nettyThreading as? NettyThreading.NonShared)?.start()
started = true started = true
restart() restart()
} }
@ -321,7 +332,7 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
} }
val bootstrap = Bootstrap() val bootstrap = Bootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this)) bootstrap.group(nettyThreading.eventLoopGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
// Delegate DNS Resolution to the proxy side, if we are using proxy. // Delegate DNS Resolution to the proxy side, if we are using proxy.
if (configuration.proxyConfig != null) { if (configuration.proxyConfig != null) {
bootstrap.resolver(NoopAddressResolverGroup.INSTANCE) bootstrap.resolver(NoopAddressResolverGroup.INSTANCE)
@ -335,14 +346,12 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
lock.withLock { lock.withLock {
log.info("Stopping connection to: $currentTarget, Local address: $localAddressString") log.info("Stopping connection to: $currentTarget, Local address: $localAddressString")
started = false started = false
if (sharedThreadPool == null) { if (nettyThreading is NettyThreading.NonShared) {
workerGroup?.shutdownGracefully() nettyThreading.stop()
workerGroup?.terminationFuture()?.sync()
} else { } else {
clientChannel?.close()?.sync() clientChannel?.close()?.sync()
} }
clientChannel = null clientChannel = null
workerGroup = null
log.info("Stopped connection to $currentTarget") log.info("Stopped connection to $currentTarget")
} }
} }
@ -384,5 +393,35 @@ class AMQPClient(private val targets: List<NetworkHostAndPort>,
val onConnection: Observable<ConnectionChange> val onConnection: Observable<ConnectionChange>
get() = _onConnection get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
sealed class NettyThreading {
abstract val eventLoopGroup: EventLoopGroup
abstract val sslDelegatedTaskExecutor: Executor
class Shared(override val eventLoopGroup: EventLoopGroup,
override val sslDelegatedTaskExecutor: ExecutorService = sslDelegatedTaskExecutor("AMQPClient")) : NettyThreading()
class NonShared(val threadPoolName: String) : NettyThreading() {
private var _eventLoopGroup: NioEventLoopGroup? = null
override val eventLoopGroup: EventLoopGroup get() = checkNotNull(_eventLoopGroup)
private var _sslDelegatedTaskExecutor: ThreadPoolExecutor? = null
override val sslDelegatedTaskExecutor: ExecutorService get() = checkNotNull(_sslDelegatedTaskExecutor)
fun start() {
check(_eventLoopGroup == null)
check(_sslDelegatedTaskExecutor == null)
_eventLoopGroup = NioEventLoopGroup(NUM_CLIENT_THREADS, DefaultThreadFactory(threadPoolName, Thread.MAX_PRIORITY))
_sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
}
fun stop() {
eventLoopGroup.shutdownGracefully()
eventLoopGroup.terminationFuture().sync()
sslDelegatedTaskExecutor.shutdown()
_eventLoopGroup = null
_sslDelegatedTaskExecutor = null
}
}
}
} }

View File

@ -21,16 +21,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.qpid.proton.engine.Delivery import org.apache.qpid.proton.engine.Delivery
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.net.BindException import java.net.BindException
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.security.cert.CertPathValidatorException
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
/** /**
@ -39,37 +38,34 @@ import kotlin.concurrent.withLock
class AMQPServer(val hostName: String, class AMQPServer(val hostName: String,
val port: Int, val port: Int,
private val configuration: AMQPConfiguration, private val configuration: AMQPConfiguration,
private val threadPoolName: String = "AMQPServer") : AutoCloseable { private val threadPoolName: String = "AMQPServer",
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : AutoCloseable {
companion object { companion object {
init { init {
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE) InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
} }
private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"
private val log = contextLogger() private val log = contextLogger()
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4) private val DEFAULT_REMOTING_THREADS = Integer.getInteger("net.corda.nodeapi.amqpserver.NumServerThreads", 4)
} }
private val lock = ReentrantLock() private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var bossGroup: EventLoopGroup? = null private var bossGroup: EventLoopGroup? = null
private var workerGroup: EventLoopGroup? = null private var workerGroup: EventLoopGroup? = null
private var serverChannel: Channel? = null private var serverChannel: Channel? = null
private val revocationChecker = configuration.revocationConfig.createPKIXRevocationChecker() private var sslDelegatedTaskExecutor: ExecutorService? = null
private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>() private val clientChannels = ConcurrentHashMap<InetSocketAddress, SocketChannel>()
private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() { private class ServerChannelInitializer(val parent: AMQPServer) : ChannelInitializer<SocketChannel>() {
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) private val keyManagerFactory = keyManagerFactory(parent.configuration.keyStore)
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) private val trustManagerFactory = trustManagerFactoryWithRevocation(
parent.configuration.trustStore,
parent.configuration.revocationConfig,
parent.distPointCrlSource
)
private val conf = parent.configuration private val conf = parent.configuration
init {
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, parent.revocationChecker))
}
override fun initChannel(ch: SocketChannel) { override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline() val pipeline = ch.pipeline()
@ -116,11 +112,12 @@ class AMQPServer(val hostName: String,
Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap) Pair(createServerSNIOpenSniHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else { } else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig) val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val delegatedTaskExecutor = checkNotNull(parent.sslDelegatedTaskExecutor)
val handler = if (amqpConfig.useOpenSsl) { val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc()) createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc(), delegatedTaskExecutor)
} else { } else {
// For javaSSL, SNI matching is handled at key manager level. // For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory) createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory, delegatedTaskExecutor)
} }
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis() handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout.toMillis()
Pair(handler, mapOf(DEFAULT to keyManagerFactory)) Pair(handler, mapOf(DEFAULT to keyManagerFactory))
@ -132,8 +129,13 @@ class AMQPServer(val hostName: String,
lock.withLock { lock.withLock {
stop() stop()
sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY)) bossGroup = NioEventLoopGroup(1, DefaultThreadFactory("$threadPoolName-boss", Thread.MAX_PRIORITY))
workerGroup = NioEventLoopGroup(NUM_SERVER_THREADS, DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)) workerGroup = NioEventLoopGroup(
remotingThreads ?: DEFAULT_REMOTING_THREADS,
DefaultThreadFactory("$threadPoolName-worker", Thread.MAX_PRIORITY)
)
val server = ServerBootstrap() val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux // TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
@ -154,22 +156,19 @@ class AMQPServer(val hostName: String,
fun stop() { fun stop() {
lock.withLock { lock.withLock {
try { serverChannel?.close()
stopping = true serverChannel = null
serverChannel?.apply { close() }
serverChannel = null
workerGroup?.shutdownGracefully() workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync() workerGroup?.terminationFuture()?.sync()
workerGroup = null
bossGroup?.shutdownGracefully() bossGroup?.shutdownGracefully()
bossGroup?.terminationFuture()?.sync() bossGroup?.terminationFuture()?.sync()
bossGroup = null
workerGroup = null sslDelegatedTaskExecutor?.shutdown()
bossGroup = null sslDelegatedTaskExecutor = null
} finally {
stopping = false
}
} }
} }
@ -226,6 +225,4 @@ class AMQPServer(val hostName: String,
private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized() private val _onConnection = PublishSubject.create<ConnectionChange>().toSerialized()
val onConnection: Observable<ConnectionChange> val onConnection: Observable<ConnectionChange>
get() = _onConnection get() = _onConnection
val softFailExceptions: List<CertPathValidatorException> get() = revocationChecker.softFailExceptions
} }

View File

@ -31,4 +31,6 @@ object AllowAllRevocationChecker : PKIXRevocationChecker() {
override fun getSoftFailExceptions(): List<CertPathValidatorException> { override fun getSoftFailExceptions(): List<CertPathValidatorException> {
return Collections.emptyList() return Collections.emptyList()
} }
override fun clone(): AllowAllRevocationChecker = this
} }

View File

@ -3,9 +3,6 @@ package net.corda.nodeapi.internal.protonwrapper.netty
import com.typesafe.config.Config import com.typesafe.config.Config
import net.corda.nodeapi.internal.config.ConfigParser import net.corda.nodeapi.internal.config.ConfigParser
import net.corda.nodeapi.internal.config.CustomConfigParser import net.corda.nodeapi.internal.config.CustomConfigParser
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import java.security.cert.PKIXRevocationChecker
/** /**
* Data structure for controlling the way how Certificate Revocation Lists are handled. * Data structure for controlling the way how Certificate Revocation Lists are handled.
@ -45,18 +42,6 @@ interface RevocationConfig {
* Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE` * Optional [CrlSource] which only makes sense with `mode` = `EXTERNAL_SOURCE`
*/ */
val externalCrlSource: CrlSource? val externalCrlSource: CrlSource?
fun createPKIXRevocationChecker(): PKIXRevocationChecker {
return when (mode) {
Mode.OFF -> AllowAllRevocationChecker
Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(externalCrlSource) { "externalCrlSource must be specfied for EXTERNAL_SOURCE" }
CordaRevocationChecker(externalCrlSource, softFail = true)
}
Mode.SOFT_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = true)
Mode.HARD_FAIL -> CordaRevocationChecker(CertDistPointCrlSource(), softFail = false)
}
}
} }
/** /**

View File

@ -1,3 +1,5 @@
@file:Suppress("ComplexMethod", "LongParameterList")
package net.corda.nodeapi.internal.protonwrapper.netty package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.buffer.ByteBufAllocator import io.netty.buffer.ByteBufAllocator
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.crypto.x509 import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import net.corda.nodeapi.internal.revocation.CordaRevocationChecker
import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.ASN1InputStream
import org.bouncycastle.asn1.ASN1Primitive import org.bouncycastle.asn1.ASN1Primitive
import org.bouncycastle.asn1.ASN1IA5String import org.bouncycastle.asn1.ASN1IA5String
@ -34,10 +38,10 @@ import java.net.URI
import java.security.KeyStore import java.security.KeyStore
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters import java.security.cert.PKIXBuilderParameters
import java.security.cert.PKIXRevocationChecker
import java.security.cert.X509CertSelector import java.security.cert.X509CertSelector
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.ThreadPoolExecutor
import javax.net.ssl.CertPathTrustManagerParameters import javax.net.ssl.CertPathTrustManagerParameters
import javax.net.ssl.KeyManagerFactory import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName import javax.net.ssl.SNIHostName
@ -46,7 +50,6 @@ import javax.net.ssl.SSLEngine
import javax.net.ssl.TrustManagerFactory import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedTrustManager import javax.net.ssl.X509ExtendedTrustManager
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.system.measureTimeMillis
private const val HOSTNAME_FORMAT = "%s.corda.net" private const val HOSTNAME_FORMAT = "%s.corda.net"
internal const val DEFAULT = "default" internal const val DEFAULT = "default"
@ -58,7 +61,6 @@ internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.proton
/** /**
* Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any. * Returns all the CRL distribution points in the certificate as [URI]s along with the CRL issuer names, if any.
*/ */
@Suppress("ComplexMethod")
fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> { fun X509Certificate.distributionPoints(): Map<URI, List<X500Principal>?> {
logger.debug { "Checking CRLDPs for $subjectX500Principal" } logger.debug { "Checking CRLDPs for $subjectX500Principal" }
@ -117,6 +119,14 @@ fun certPathToString(certPath: Array<out X509Certificate>?): String {
return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" } return certPath.joinToString(System.lineSeparator()) { " ${it.toSimpleString()}" }
} }
/**
* Create an executor for processing SSL handshake tasks asynchronously (see [SSLEngine.getDelegatedTask]). The max number of threads is 3,
* which is the typical number of CRLs expected in a Corda TLS cert path. The executor needs to be passed to the [SslHandler] constructor.
*/
fun sslDelegatedTaskExecutor(parentPoolName: String): ThreadPoolExecutor {
return namedThreadPoolExecutor(maxPoolSize = 3, poolName = "$parentPoolName-ssltask")
}
@VisibleForTesting @VisibleForTesting
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() { class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
companion object { companion object {
@ -179,32 +189,11 @@ class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509Ex
} }
private object LoggingImmediateExecutor : Executor {
override fun execute(command: Runnable) {
val log = LoggerFactory.getLogger(javaClass)
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
try {
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
log.debug("Entering SSL command $commandName")
val elapsedTime = measureTimeMillis { command.run() }
log.debug("Exiting SSL command $elapsedTime millis")
if (elapsedTime > 100) {
log.info("Command: $commandName took $elapsedTime millis to execute")
}
}
catch (ex: Exception) {
log.error("Caught exception in SSL handler executor", ex)
throw ex
}
}
}
internal fun createClientSslHandler(target: NetworkHostAndPort, internal fun createClientSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>, expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine(target.host, target.port) val sslEngine = sslContext.createSSLEngine(target.host, target.port)
sslEngine.useClientMode = true sslEngine.useClientMode = true
@ -216,14 +205,15 @@ internal fun createClientSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
} }
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
internal fun createClientOpenSslHandler(target: NetworkHostAndPort, internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
expectedRemoteLegalNames: Set<CordaX500Name>, expectedRemoteLegalNames: Set<CordaX500Name>,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory, trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler { alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build() val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
val sslEngine = sslContext.newEngine(alloc, target.host, target.port) val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray() sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
@ -233,12 +223,13 @@ internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single()))) sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
} }
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
internal fun createServerSslHandler(keyStore: CertificateStore, internal fun createServerSslHandler(keyStore: CertificateStore,
keyManagerFactory: KeyManagerFactory, keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory): SslHandler { trustManagerFactory: TrustManagerFactory,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory) val sslContext = createAndInitSslContext(keyManagerFactory, trustManagerFactory)
val sslEngine = sslContext.createSSLEngine() val sslEngine = sslContext.createSSLEngine()
sslEngine.useClientMode = false sslEngine.useClientMode = false
@ -249,39 +240,29 @@ internal fun createServerSslHandler(keyStore: CertificateStore,
val sslParameters = sslEngine.sslParameters val sslParameters = sslEngine.sslParameters
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore)) sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
sslEngine.sslParameters = sslParameters sslEngine.sslParameters = sslParameters
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory, internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
trustManagerFactory: TrustManagerFactory, trustManagerFactory: TrustManagerFactory,
alloc: ByteBufAllocator): SslHandler { alloc: ByteBufAllocator,
delegateTaskExecutor: Executor): SslHandler {
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build() val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
val sslEngine = sslContext.newEngine(alloc) val sslEngine = sslContext.newEngine(alloc)
sslEngine.useClientMode = false sslEngine.useClientMode = false
return SslHandler(sslEngine, false, LoggingImmediateExecutor) return SslHandler(sslEngine, false, delegateTaskExecutor)
} }
fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SSLContext { fun createAndInitSslContext(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory?): SSLContext {
val sslContext = SSLContext.getInstance("TLS") val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers val trustManagers = trustManagerFactory
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java) ?.trustManagers
.map { LoggingTrustManagerWrapper(it) }.toTypedArray() ?.map { if (it is X509ExtendedTrustManager) LoggingTrustManagerWrapper(it) else it }
sslContext.init(keyManagers, trustManagers, newSecureRandom()) ?.toTypedArray()
sslContext.init(keyManagerFactory.keyManagers, trustManagers, newSecureRandom())
return sslContext return sslContext
} }
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationConfig: RevocationConfig): CertPathTrustManagerParameters {
return initialiseTrustStoreAndEnableCrlChecking(trustStore, revocationConfig.createPKIXRevocationChecker())
}
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore,
revocationChecker: PKIXRevocationChecker): CertPathTrustManagerParameters {
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
return CertPathTrustManagerParameters(pkixParams)
}
/** /**
* Creates a special SNI handler used only when openSSL is used for AMQPServer * Creates a special SNI handler used only when openSSL is used for AMQPServer
*/ */
@ -296,14 +277,13 @@ internal fun createServerSNIOpenSniHandler(keyManagerFactoriesMap: Map<String, K
return SniHandler(mapping.build()) return SniHandler(mapping.build())
} }
@Suppress("SpreadOperator")
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder { private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
return SslContextBuilder.forServer(keyManagerFactory) return SslContextBuilder.forServer(keyManagerFactory)
.sslProvider(SslProvider.OPENSSL) .sslProvider(SslProvider.OPENSSL)
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)) .trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
.clientAuth(ClientAuth.REQUIRE) .clientAuth(ClientAuth.REQUIRE)
.ciphers(ArtemisTcpTransport.CIPHER_SUITES) .ciphers(ArtemisTcpTransport.CIPHER_SUITES)
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()) .protocols(ArtemisTcpTransport.TLS_VERSIONS)
} }
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> { internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
@ -327,7 +307,38 @@ internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKe
// 2nd parameter `password` - the password for recovering keys in the KeyStore // 2nd parameter `password` - the password for recovering keys in the KeyStore
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray()) fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal) fun keyManagerFactory(keyStore: CertificateStore): KeyManagerFactory {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore)
return keyManagerFactory
}
fun trustManagerFactory(trustStore: CertificateStore): TrustManagerFactory {
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(trustStore.value.internal)
return trustManagerFactory
}
fun trustManagerFactoryWithRevocation(trustStore: CertificateStore,
revocationConfig: RevocationConfig,
crlSource: CrlSource): TrustManagerFactory {
val revocationChecker = when (revocationConfig.mode) {
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
val externalCrlSource = requireNotNull(revocationConfig.externalCrlSource) {
"externalCrlSource must be specfied for EXTERNAL_SOURCE"
}
CordaRevocationChecker(externalCrlSource, softFail = true)
}
RevocationConfig.Mode.SOFT_FAIL -> CordaRevocationChecker(crlSource, softFail = true)
RevocationConfig.Mode.HARD_FAIL -> CordaRevocationChecker(crlSource, softFail = false)
}
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
pkixParams.addCertPathChecker(revocationChecker)
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(CertPathTrustManagerParameters(pkixParams))
return trustManagerFactory
}
/** /**
* Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target * Method that converts a [CordaX500Name] to a a valid hostname (RFC-1035). It's used for SNI to indicate the target

View File

@ -5,6 +5,9 @@ import com.github.benmanes.caffeine.cache.LoadingCache
import net.corda.core.internal.readFully import net.corda.core.internal.readFully
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.nodeapi.internal.crypto.X509CertificateFactory import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.toSimpleString import net.corda.nodeapi.internal.crypto.toSimpleString
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
@ -12,60 +15,71 @@ import net.corda.nodeapi.internal.protonwrapper.netty.distributionPoints
import java.net.URI import java.net.URI
import java.security.cert.X509CRL import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit import java.time.Duration
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
/** /**
* [CrlSource] which downloads CRLs from the distribution points in the X509 certificate. * [CrlSource] which downloads CRLs from the distribution points in the X509 certificate and caches them.
*/ */
@Suppress("TooGenericExceptionCaught") @Suppress("TooGenericExceptionCaught")
class CertDistPointCrlSource : CrlSource { class CertDistPointCrlSource(cacheSize: Long = DEFAULT_CACHE_SIZE,
cacheExpiry: Duration = DEFAULT_CACHE_EXPIRY,
private val connectTimeout: Duration = DEFAULT_CONNECT_TIMEOUT,
private val readTimeout: Duration = DEFAULT_READ_TIMEOUT) : CrlSource {
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
// The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a // The default SSL handshake timeout is 60s (DEFAULT_SSL_HANDSHAKE_TIMEOUT). Considering there are 3 CRLs endpoints to check in a
// node handshake, we want to keep the total timeout within that. // node handshake, we want to keep the total timeout within that.
private const val DEFAULT_CONNECT_TIMEOUT = 9_000 private val DEFAULT_CONNECT_TIMEOUT = 9.seconds
private const val DEFAULT_READ_TIMEOUT = 9_000 private val DEFAULT_READ_TIMEOUT = 9.seconds
private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore) private const val DEFAULT_CACHE_SIZE = 185L // Same default as the JDK (URICertStore)
private const val DEFAULT_CACHE_EXPIRY = 5 * 60 * 1000L private val DEFAULT_CACHE_EXPIRY = 5.minutes
private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder() val SINGLETON = CertDistPointCrlSource(
.maximumSize(java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE)) cacheSize = java.lang.Long.getLong("net.corda.dpcrl.cache.size", DEFAULT_CACHE_SIZE),
.expireAfterWrite(java.lang.Long.getLong("net.corda.dpcrl.cache.expiry", DEFAULT_CACHE_EXPIRY), TimeUnit.MILLISECONDS) cacheExpiry = java.lang.Long.getLong("net.corda.dpcrl.cache.expiry")?.let(Duration::ofMillis) ?: DEFAULT_CACHE_EXPIRY,
.build(::retrieveCRL) connectTimeout = java.lang.Long.getLong("net.corda.dpcrl.connect.timeout")?.let(Duration::ofMillis) ?: DEFAULT_CONNECT_TIMEOUT,
readTimeout = java.lang.Long.getLong("net.corda.dpcrl.read.timeout")?.let(Duration::ofMillis) ?: DEFAULT_READ_TIMEOUT
)
}
private val connectTimeout = Integer.getInteger("net.corda.dpcrl.connect.timeout", DEFAULT_CONNECT_TIMEOUT) private val cache: LoadingCache<URI, X509CRL> = Caffeine.newBuilder()
private val readTimeout = Integer.getInteger("net.corda.dpcrl.read.timeout", DEFAULT_READ_TIMEOUT) .maximumSize(cacheSize)
.expireAfterWrite(cacheExpiry)
.build(::retrieveCRL)
private fun retrieveCRL(uri: URI): X509CRL { private fun retrieveCRL(uri: URI): X509CRL {
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
val bytes = try { val bytes = try {
val conn = uri.toURL().openConnection() val conn = uri.toURL().openConnection()
conn.connectTimeout = connectTimeout conn.connectTimeout = connectTimeout.toMillis().toInt()
conn.readTimeout = readTimeout conn.readTimeout = readTimeout.toMillis().toInt()
// Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes // Read all bytes first and then pass them into the CertificateFactory. This may seem unnecessary when generateCRL already takes
// in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException // in an InputStream, but the JDK implementation (sun.security.provider.X509Factory.engineGenerateCRL) converts any IOException
// into CRLException and drops the cause chain. // into CRLException and drops the cause chain.
conn.getInputStream().readFully() conn.getInputStream().readFully()
} catch (e: Exception) { } catch (e: Exception) {
if (logger.isDebugEnabled) { if (logger.isDebugEnabled) {
logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e) logger.debug("Unable to download CRL from $uri (${System.currentTimeMillis() - start}ms)", e)
}
throw e
} }
val duration = System.currentTimeMillis() - start throw e
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
} }
val duration = System.currentTimeMillis() - start
val crl = try {
X509CertificateFactory().generateCRL(bytes.inputStream())
} catch (e: Exception) {
if (logger.isDebugEnabled) {
logger.debug("Invalid CRL from $uri (${duration}ms)", e)
}
throw e
}
logger.debug { "CRL from $uri (${duration}ms): ${crl.toSimpleString()}" }
return crl
}
fun clearCache() {
cache.invalidateAll()
} }
override fun fetch(certificate: X509Certificate): Set<X509CRL> { override fun fetch(certificate: X509Certificate): Set<X509CRL> {

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
@ -161,11 +162,9 @@ class TlsDiffAlgorithmsTest(private val serverAlgo: String, private val clientAl
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply { return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(keyStore)
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(trustStore)
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom()) init(keyManagers, trustManagers, newSecureRandom())
} }

View File

@ -4,7 +4,8 @@ import net.corda.core.crypto.newSecureRandom
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Ignore import org.junit.Ignore
import org.junit.Rule import org.junit.Rule
@ -18,7 +19,6 @@ import java.io.IOException
import java.net.InetAddress import java.net.InetAddress
import java.net.InetSocketAddress import java.net.InetSocketAddress
import javax.net.ssl.* import javax.net.ssl.*
import javax.net.ssl.SNIHostName
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
@ -209,11 +209,9 @@ class TlsDiffProtocolsTest(private val serverAlgo: String, private val clientAlg
private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext { private fun createSslContext(keyStore: CertificateStore, trustStore: CertificateStore): SSLContext {
return SSLContext.getInstance("TLS").apply { return SSLContext.getInstance("TLS").apply {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(keyStore)
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(trustStore)
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
init(keyManagers, trustManagers, newSecureRandom()) init(keyManagers, trustManagers, newSecureRandom())
} }

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.protonwrapper.netty package net.corda.nodeapi.internal.protonwrapper.netty
import io.netty.util.concurrent.ImmediateExecutor
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
@ -8,10 +9,9 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.testing.internal.fixedCrlSource
import org.junit.Test import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName import javax.net.ssl.SNIHostName
import javax.net.ssl.TrustManagerFactory
import kotlin.test.assertEquals import kotlin.test.assertEquals
class SSLHelperTest { class SSLHelperTest {
@ -20,15 +20,21 @@ class SSLHelperTest {
val legalName = CordaX500Name("Test", "London", "GB") val legalName = CordaX500Name("Test", "London", "GB")
val sslConfig = configureTestSSL(legalName) val sslConfig = configureTestSSL(legalName)
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
val keyStore = sslConfig.keyStore val trustManagerFactory = trustManagerFactoryWithRevocation(
keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false)) sslConfig.trustStore.get(),
val trustStore = sslConfig.trustStore RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL),
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL))) fixedCrlSource(emptySet())
)
val sslHandler = createClientSslHandler(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory) val sslHandler = createClientSslHandler(
NetworkHostAndPort("localhost", 1234),
setOf(legalName),
keyManagerFactory,
trustManagerFactory,
ImmediateExecutor.INSTANCE
)
val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase() val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
// These hardcoded values must not be changed, something is broken if you have to change these hardcoded values. // These hardcoded values must not be changed, something is broken if you have to change these hardcoded values.

View File

@ -2,15 +2,13 @@ package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.node.internal.network.CrlServer import net.corda.testing.node.internal.network.CrlServer
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.math.BigInteger
class CertDistPointCrlSourceTest { class CertDistPointCrlSourceTest {
private lateinit var crlServer: CrlServer private lateinit var crlServer: CrlServer
@ -39,13 +37,14 @@ class CertDistPointCrlSourceTest {
assertThat(single().revokedCertificates).isNull() assertThat(single().revokedCertificates).isNull()
} }
val nodeCaCert = crlServer.replaceNodeCertDistPoint(createDevNodeCa(crlServer.intermediateCa, ALICE_NAME).certificate) crlSource.clearCache()
crlServer.revokedNodeCerts += listOf(BigInteger.ONE, BigInteger.TEN) crlServer.revokedIntermediateCerts += DEV_INTERMEDIATE_CA.certificate
with(crlSource.fetch(nodeCaCert)) { // Use a different cert to avoid the cache with(crlSource.fetch(crlServer.intermediateCa.certificate)) {
assertThat(size).isEqualTo(1) assertThat(size).isEqualTo(1)
val revokedCertificates = single().revokedCertificates val revokedCertificates = single().revokedCertificates
assertThat(revokedCertificates.map { it.serialNumber }).containsExactlyInAnyOrder(BigInteger.ONE, BigInteger.TEN) // This also tests clearCache() works.
assertThat(revokedCertificates.map { it.serialNumber }).containsExactly(DEV_INTERMEDIATE_CA.certificate.serialNumber)
} }
} }
} }

View File

@ -5,7 +5,7 @@ import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource import net.corda.testing.internal.fixedCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test import org.junit.Test
import java.math.BigInteger import java.math.BigInteger
@ -41,10 +41,8 @@ class CordaRevocationCheckerTest {
val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl") val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL
val crlSource = object : CrlSource { val checker = CordaRevocationChecker(
override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl) crlSource = fixedCrlSource(setOf(crl)),
}
val checker = CordaRevocationChecker(crlSource,
softFail = true, softFail = true,
dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) } dateSource = { Date.from(date.atStartOfDay().toInstant(ZoneOffset.UTC)) }
) )

View File

@ -1,20 +1,16 @@
package net.corda.node.internal.artemis package net.corda.nodeapi.internal.revocation
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.utilities.days import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.node.internal.artemis.CertificateChainCheckPolicy.RevocationCheck import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509KeyStore
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
@ -22,15 +18,18 @@ import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.Parameterized import org.junit.runners.Parameterized
import java.io.File import java.io.File
import java.security.KeyPair
import java.security.KeyStore import java.security.KeyStore
import java.security.PrivateKey import java.security.PrivateKey
import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.* import java.util.*
import javax.net.ssl.X509TrustManager
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import kotlin.test.assertFails import kotlin.test.assertFailsWith
@RunWith(Parameterized::class) @RunWith(Parameterized::class)
class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) { class RevocationTest(private val revocationMode: RevocationConfig.Mode) {
companion object { companion object {
@JvmStatic @JvmStatic
@Parameterized.Parameters(name = "revocationMode = {0}") @Parameterized.Parameters(name = "revocationMode = {0}")
@ -45,8 +44,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var doormanCRL: File private lateinit var doormanCRL: File
private lateinit var tlsCRL: File private lateinit var tlsCRL: File
private val keyStore = KeyStore.getInstance("JKS") private lateinit var trustManager: X509TrustManager
private val trustStore = KeyStore.getInstance("JKS")
private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val rootKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) private val tlsCRLIssuerKeyPair = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
@ -61,7 +59,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
private lateinit var tlsCert: X509Certificate private lateinit var tlsCert: X509Certificate
private val chain private val chain
get() = listOf(tlsCert, nodeCACert, doormanCert, rootCert).toTypedArray() get() = arrayOf(tlsCert, nodeCACert, doormanCert, rootCert)
@Before @Before
fun before() { fun before() {
@ -72,10 +70,18 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair) rootCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=root"), rootKeyPair)
tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair) tlsCRLIssuerCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=issuer"), tlsCRLIssuerKeyPair)
val trustStore = KeyStore.getInstance("JKS")
trustStore.load(null, null) trustStore.load(null, null)
trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert) trustStore.setCertificateEntry("cordatlscrlsigner", tlsCRLIssuerCert)
trustStore.setCertificateEntry("cordarootca", rootCert) trustStore.setCertificateEntry("cordarootca", rootCert)
val trustManagerFactory = trustManagerFactoryWithRevocation(
CertificateStore.of(X509KeyStore(trustStore, "pass"), "pass", "pass"),
RevocationConfigImpl(revocationMode),
CertDistPointCrlSource()
)
trustManager = trustManagerFactory.trustManagers.single() as X509TrustManager
doormanCert = X509Utilities.createCertificate( doormanCert = X509Utilities.createCertificate(
CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public, CertificateType.INTERMEDIATE_CA, rootCert, rootKeyPair, X500Principal("CN=doorman"), doormanKeyPair.public,
crlDistPoint = rootCRL.toURI().toString() crlDistPoint = rootCRL.toURI().toString()
@ -89,43 +95,34 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
) )
rootCRL.createCRL(rootCert, rootKeyPair.private, false) rootCRL.writeCRL(rootCert, rootKeyPair.private, false)
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false) doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false)
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true) tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true)
} }
private fun File.createCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) { private fun File.writeCRL(certificate: X509Certificate, privateKey: PrivateKey, indirect: Boolean, vararg revoked: X509Certificate) {
val builder = JcaX509v2CRLBuilder(certificate.subjectX500Principal, Date()) val crl = createCRL(
builder.setNextUpdate(Date.from(Date().toInstant() + 7.days)) CertificateAndKeyPair(certificate, KeyPair(certificate.publicKey, privateKey)),
builder.addExtension(Extension.issuingDistributionPoint, true, IssuingDistributionPoint(null, indirect, false)) revoked.asList(),
revoked.forEach { indirect = indirect
val extensionsGenerator = ExtensionsGenerator() )
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(CRLReason.keyCompromise)) writeBytes(crl.encoded)
// Certificate issuer is required for indirect CRL
val certificateIssuerName = X500Name.getInstance(it.issuerX500Principal.encoded)
extensionsGenerator.addExtension(Extension.certificateIssuer, true, GeneralNames(GeneralName(certificateIssuerName)))
builder.addCRLEntry(it.serialNumber, Date(), extensionsGenerator.generate())
}
val holder = builder.build(JcaContentSignerBuilder("SHA256withECDSA").setProvider(Crypto.findProvider("BC")).build(privateKey))
outputStream().use { it.write(holder.encoded) }
} }
private fun assertFailsFor(vararg modes: RevocationConfig.Mode, block: () -> Unit) { private fun assertFailsFor(vararg modes: RevocationConfig.Mode) {
if (revocationMode in modes) assertFails(block) else block() if (revocationMode in modes) assertFailsWith(CertificateException::class, ::doRevocationCheck) else doRevocationCheck()
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `ok with empty CRLs`() { fun `ok with empty CRLs`() {
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) doRevocationCheck()
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `soft fail with revoked TLS certificate`() { fun `soft fail with revoked TLS certificate`() {
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert) tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, tlsCert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -136,9 +133,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -148,9 +143,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown") crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name("CN=unknown")
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -160,9 +153,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = tlsCRL.toURI().toString() crlDistPoint = tlsCRL.toURI().toString()
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -172,18 +163,16 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public, CertificateType.TLS, nodeCACert, nodeCAKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded) crlDistPoint = tlsCRL.toURI().toString(), crlIssuer = X500Name.getInstance(tlsCRLIssuerCert.issuerX500Principal.encoded)
) )
tlsCRL.createCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert) tlsCRL.writeCRL(tlsCRLIssuerCert, tlsCRLIssuerKeyPair.private, true, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) doRevocationCheck()
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
fun `soft fail with revoked node CA certificate`() { fun `soft fail with revoked node CA certificate`() {
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, nodeCACert) doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, nodeCACert)
assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.SOFT_FAIL, RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -193,9 +182,7 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman" crlDistPoint = "http://unknown-host:10000/certificate-revocation-list/doorman"
) )
assertFailsFor(RevocationConfig.Mode.HARD_FAIL) { assertFailsFor(RevocationConfig.Mode.HARD_FAIL)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain)
}
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)
@ -205,8 +192,12 @@ class RevocationCheckTest(private val revocationMode: RevocationConfig.Mode) {
CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public, CertificateType.NODE_CA, doormanCert, doormanKeyPair, X500Principal("CN=other"), otherKeyPair.public,
crlDistPoint = doormanCRL.toURI().toString() crlDistPoint = doormanCRL.toURI().toString()
) )
doormanCRL.createCRL(doormanCert, doormanKeyPair.private, false, otherCert) doormanCRL.writeCRL(doormanCert, doormanKeyPair.private, false, otherCert)
RevocationCheck(revocationMode).createCheck(keyStore, trustStore).checkCertificateChain(chain) doRevocationCheck()
}
private fun doRevocationCheck() {
trustManager.checkClientTrusted(chain, "ECDHE_ECDSA")
} }
} }

View File

@ -270,8 +270,6 @@ tasks.register('integrationTest', Test) {
testClassesDirs = sourceSets.integrationTest.output.classesDirs testClassesDirs = sourceSets.integrationTest.output.classesDirs
classpath = sourceSets.integrationTest.runtimeClasspath classpath = sourceSets.integrationTest.runtimeClasspath
maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger() maxParallelForks = (System.env.CORDA_NODE_INT_TESTING_FORKS == null) ? 1 : "$System.env.CORDA_NODE_INT_TESTING_FORKS".toInteger()
// CertificateRevocationListNodeTests
systemProperty 'net.corda.dpcrl.connect.timeout', '4000'
} }
tasks.register('slowIntegrationTest', Test) { tasks.register('slowIntegrationTest', Test) {

View File

@ -15,12 +15,15 @@ import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionResult import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionResult
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.initialiseTrustStoreAndEnableCrlChecking import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.fixedCrlSource
import org.junit.Assume.assumeFalse import org.junit.Assume.assumeFalse
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
@ -98,11 +101,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val maxMessageSize: Int = MAX_MESSAGE_SIZE override val maxMessageSize: Int = MAX_MESSAGE_SIZE
} }
serverKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) serverKeyManagerFactory = keyManagerFactory(keyStore)
serverTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
serverKeyManagerFactory.init(keyStore) serverTrustManagerFactory = trustManagerFactoryWithRevocation(
serverTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(serverAmqpConfig.trustStore, serverAmqpConfig.revocationConfig)) serverAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
} }
private fun setupClientCertificates() { private fun setupClientCertificates() {
@ -129,11 +134,13 @@ class AMQPClientSslErrorsTest(@Suppress("unused") private val iteration: Int) {
override val sslHandshakeTimeout: Duration = 3.seconds override val sslHandshakeTimeout: Duration = 3.seconds
} }
clientKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) clientKeyManagerFactory = keyManagerFactory(keyStore)
clientTrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
clientKeyManagerFactory.init(keyStore) clientTrustManagerFactory = trustManagerFactoryWithRevocation(
clientTrustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(clientAmqpConfig.trustStore, clientAmqpConfig.revocationConfig)) clientAmqpConfig.trustStore,
RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL),
fixedCrlSource(emptySet())
)
} }
@Test(timeout = 300_000) @Test(timeout = 300_000)

View File

@ -1,3 +1,5 @@
@file:Suppress("LongParameterList")
package net.corda.node.amqp package net.corda.node.amqp
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
@ -5,10 +7,10 @@ import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.internal.rootCause
import net.corda.core.internal.times import net.corda.core.internal.times
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
@ -18,64 +20,68 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.config.CertificateStoreSupplier import net.corda.nodeapi.internal.config.CertificateStoreSupplier
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.ConnectionChange
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.CHARLIE_NAME import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.node.internal.network.CrlServer import net.corda.testing.node.internal.network.CrlServer
import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.EMPTY_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.FORBIDDEN_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL import net.corda.testing.node.internal.network.CrlServer.Companion.NODE_CRL
import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint import net.corda.testing.node.internal.network.CrlServer.Companion.withCrlDistPoint
import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import java.net.SocketTimeoutException import java.io.Closeable
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.util.concurrent.BlockingQueue
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals import java.util.stream.IntStream
@Suppress("LongParameterList") abstract class AbstractServerRevocationTest {
class CertificateRevocationListNodeTests {
@Rule @Rule
@JvmField @JvmField
val temporaryFolder = TemporaryFolder() val temporaryFolder = TemporaryFolder()
private val portAllocation = incrementalPortAllocation() private val portAllocation = incrementalPortAllocation()
private val serverPort = portAllocation.nextPort() protected val serverPort = portAllocation.nextPort()
private lateinit var crlServer: CrlServer protected lateinit var crlServer: CrlServer
private lateinit var amqpServer: AMQPServer private val amqpClients = ArrayList<AMQPClient>()
private lateinit var amqpClient: AMQPClient
private abstract class AbstractNodeConfiguration : NodeConfiguration protected lateinit var defaultCrlDistPoints: CrlDistPoints
protected abstract class AbstractNodeConfiguration : NodeConfiguration
companion object { companion object {
private val unreachableIpCounter = AtomicInteger(1) private val unreachableIpCounter = AtomicInteger(1)
private val crlConnectTimeout = Duration.ofMillis(System.getProperty("net.corda.dpcrl.connect.timeout").toLong()) val crlConnectTimeout = 2.seconds
/** /**
* Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes * Use this method to get a unqiue unreachable IP address. Subsequent uses of the same IP for connection timeout testing purposes
* may not work as the OS process may cache the timeout result. * may not work as the OS process may cache the timeout result.
*/ */
private fun newUnreachableIpAddress(): String { private fun newUnreachableIpAddress(): NetworkHostAndPort {
check(unreachableIpCounter.get() != 255) check(unreachableIpCounter.get() != 255)
return "10.255.255.${unreachableIpCounter.getAndIncrement()}" return NetworkHostAndPort("10.255.255", unreachableIpCounter.getAndIncrement())
} }
} }
@ -85,252 +91,190 @@ class CertificateRevocationListNodeTests {
Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME) Crypto.findProvider(BouncyCastleProvider.PROVIDER_NAME)
crlServer = CrlServer(NetworkHostAndPort("localhost", 0)) crlServer = CrlServer(NetworkHostAndPort("localhost", 0))
crlServer.start() crlServer.start()
defaultCrlDistPoints = CrlDistPoints(crlServer.hostAndPort)
} }
@After @After
fun tearDown() { fun tearDown() {
if (::amqpClient.isInitialized) { amqpClients.parallelStream().forEach(AMQPClient::close)
amqpClient.close()
}
if (::amqpServer.isInitialized) {
amqpServer.close()
}
if (::crlServer.isInitialized) { if (::crlServer.isInitialized) {
crlServer.close() crlServer.close()
} }
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection works and soft fail is enabled`() { fun `connection succeeds when soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection works and soft fail is disabled`() { fun `connection succeeds when soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is enabled`() { fun `connection fails when client's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
revokeClientCert = true, revokeClientCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when client's certificate is revoked and soft fail is disabled`() { fun `connection fails when client's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
revokeClientCert = true, revokeClientCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is enabled`() { fun `connection fails when server's certificate is revoked and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
revokeServerCert = true, revokeServerCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when servers's certificate is revoked and soft fail is disabled`() { fun `connection fails when server's certificate is revoked and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
revokeServerCert = true, revokeServerCert = true,
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL cannot be obtained and soft fail is enabled`() { fun `connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when CRL cannot be obtained and soft fail is disabled`() { fun `connection fails when CRL cannot be obtained and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/invalid.crl", clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = "non-existent.crl"),
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL is not defined and soft fail is enabled`() { fun `connection succeeds when CRL is not defined for node CA cert and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = null, clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when CRL is not defined and soft fail is disabled`() { fun `connection fails when CRL is not defined for node CA cert and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = false, crlCheckSoftFail = false,
nodeCrlDistPoint = null, clientCrlDistPoints = defaultCrlDistPoints.copy(nodeCa = null),
expectedConnectStatus = false expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL retrieval is forbidden and soft fail is enabled`() { fun `connection succeeds when CRL is not defined for TLS cert and soft fail is enabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL", clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
expectedConnectStatus = true expectedConnectedStatus = true
) )
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() { fun `connection fails when CRL is not defined for TLS cert and soft fail is disabled`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = false,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl", clientCrlDistPoints = defaultCrlDistPoints.copy(tls = null),
sslHandshakeTimeout = crlConnectTimeout * 3, expectedConnectedStatus = false
expectedConnectStatus = true
) )
val timeoutExceptions = (amqpServer.softFailExceptions + amqpClient.softFailExceptions)
.map { it.rootCause }
.filterIsInstance<SocketTimeoutException>()
assertThat(timeoutExceptions).isNotEmpty
} }
@Test(timeout=300_000) @Test(timeout=300_000)
fun `AMQP server connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() { fun `connection succeeds when CRL endpoint is unreachable, soft fail is enabled and CRL timeouts are within SSL handshake timeout`() {
verifyAMQPConnection( verifyConnection(
crlCheckSoftFail = true,
sslHandshakeTimeout = crlConnectTimeout * 4,
clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = true
)
}
@Test(timeout=300_000)
fun `connection fails when CRL endpoint is unreachable, despite soft fail enabled, when CRL timeouts are not within SSL handshake timeout`() {
verifyConnection(
crlCheckSoftFail = true, crlCheckSoftFail = true,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2, sslHandshakeTimeout = crlConnectTimeout / 2,
expectedConnectStatus = false clientCrlDistPoints = defaultCrlDistPoints.copy(crlServerAddress = newUnreachableIpAddress()),
expectedConnectedStatus = false
) )
} }
@Test(timeout=300_000) @Test(timeout = 300_000)
fun `verify CRL algorithms`() { fun `influx of new clients during CRL endpoint downtime does not cause existing connections to drop`() {
val crl = crlServer.createRevocationList( val serverCrlSource = CertDistPointCrlSource()
"SHA256withECDSA", // Start the server and verify the first client has connected
crlServer.rootCa, val firstClientConnectionChangeStatus = verifyConnection(
EMPTY_CRL, crlCheckSoftFail = true,
true, crlSource = serverCrlSource,
emptyList() // In general, N remoting threads will naturally support N-1 new handshaking clients plus one thread for heartbeating with
// existing clients. The trick is to make sure at least N new clients are also supported.
remotingThreads = 2,
expectedConnectedStatus = true
) )
// This should pass.
crl.verify(crlServer.rootCa.keyPair.public)
// Try changing the algorithm to EC will fail. // Now simulate the CRL endpoint becoming very slow/unreachable
assertThatIllegalArgumentException().isThrownBy { crlServer.delay = 10.minutes
crlServer.createRevocationList( // And pretend enough time has elapsed that the cached CRLs have expired and need downloading again
"EC", serverCrlSource.clearCache()
crlServer.rootCa,
EMPTY_CRL, // Now a bunch of new clients have arrived and want to handshake with the server, which will potentially cause the server's Netty
true, // threads to be tied up in trying to download the CRLs.
emptyList() IntStream.range(0, 2).parallel().forEach { clientIndex ->
val (newClient, _) = createAMQPClient(
serverPort,
crlCheckSoftFail = true,
legalName = CordaX500Name("NewClient$clientIndex", "London", "GB"),
crlDistPoints = defaultCrlDistPoints
) )
}.withMessage("Unknown signature type requested: EC") newClient.start()
}
// Make sure there are no further connection change updates, i.e. the first client stays connected throughout this whole saga
assertThat(firstClientConnectionChangeStatus.poll(30, TimeUnit.SECONDS)).isNull()
} }
@Test(timeout = 300_000) protected abstract fun verifyConnection(crlCheckSoftFail: Boolean,
fun `Artemis server connection succeeds with soft fail CRL check`() { crlSource: CertDistPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout),
verifyArtemisConnection( sslHandshakeTimeout: Duration? = null,
crlCheckSoftFail = true, remotingThreads: Int? = null,
crlCheckArtemisServer = true, clientCrlDistPoints: CrlDistPoints = defaultCrlDistPoints,
expectedStatus = MessageStatus.Acknowledged revokeClientCert: Boolean = false,
) revokeServerCert: Boolean = false,
} expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange>
@Test(timeout = 300_000) protected fun createAMQPClient(targetPort: Int,
fun `Artemis server connection succeeds with hard fail CRL check`() { crlCheckSoftFail: Boolean,
verifyArtemisConnection( legalName: CordaX500Name,
crlCheckSoftFail = false, crlDistPoints: CrlDistPoints): Pair<AMQPClient, X509Certificate> {
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with soft fail CRL check on unreachable URL if CRL timeout is within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Acknowledged,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout * 3
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on unreachable URL if CRL timeout is not within SSL handshake timeout`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedConnected = false,
nodeCrlDistPoint = "http://${newUnreachableIpAddress()}/crl/unreachable.crl",
sslHandshakeTimeout = crlConnectTimeout / 2
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with hard fail CRL check on unavailable URL`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
nodeCrlDistPoint = "http://${crlServer.hostAndPort}/crl/$FORBIDDEN_CRL"
)
}
@Test(timeout = 300_000)
fun `Artemis server connection fails with soft fail CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = true,
crlCheckArtemisServer = true,
expectedStatus = MessageStatus.Rejected,
revokedNodeCert = true
)
}
@Test(timeout = 300_000)
fun `Artemis server connection succeeds with disabled CRL check on revoked node certificate`() {
verifyArtemisConnection(
crlCheckSoftFail = false,
crlCheckArtemisServer = false,
expectedStatus = MessageStatus.Acknowledged,
revokedNodeCert = true
)
}
private fun createAMQPClient(targetPort: Int,
crlCheckSoftFail: Boolean,
legalName: CordaX500Name,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL",
maxMessageSize: Int = MAX_MESSAGE_SIZE): X509Certificate {
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates" val certificatesDirectory = baseDirectory / "certificates"
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory) val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory)
@ -344,31 +288,128 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
} }
clientConfig.configureWithDevSSLCertificate() clientConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = clientConfig.p2pSslOptions.keyStore.get() val keyStore = clientConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration { val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore override val keyStore = keyStore
override val trustStore = clientConfig.p2pSslOptions.trustStore.get() override val trustStore = clientConfig.p2pSslOptions.trustStore.get()
override val maxMessageSize: Int = maxMessageSize override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val trace: Boolean = true
} }
amqpClient = AMQPClient( val amqpClient = AMQPClient(
listOf(NetworkHostAndPort("localhost", targetPort)), listOf(NetworkHostAndPort("localhost", targetPort)),
setOf(CHARLIE_NAME), setOf(CHARLIE_NAME),
amqpConfig, amqpConfig,
threadPoolName = legalName.organisation nettyThreading = AMQPClient.NettyThreading.NonShared(legalName.organisation),
distPointCrlSource = CertDistPointCrlSource(connectTimeout = crlConnectTimeout)
) )
amqpClients += amqpClient
return Pair(amqpClient, nodeCert)
}
return nodeCert protected fun AMQPClient.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val connectionChangeStatus = LinkedBlockingQueue<ConnectionChange>()
onConnection.subscribe { connectionChangeStatus.add(it) }
start()
assertThat(connectionChangeStatus.take().connected).isEqualTo(expectedConnectedStatus)
return connectionChangeStatus
}
protected data class CrlDistPoints(val crlServerAddress: NetworkHostAndPort,
val nodeCa: String? = NODE_CRL,
val tls: String? = EMPTY_CRL) {
private val nodeCaCertCrlDistPoint: String? get() = nodeCa?.let { "http://$crlServerAddress/crl/$it" }
private val tlsCertCrlDistPoint: String? get() = tls?.let { "http://$crlServerAddress/crl/$it" }
fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier,
p2pSslConfiguration: MutualSslConfiguration,
crlServer: CrlServer): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCertCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCertCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
}
}
class AMQPServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var amqpServer: AMQPServer
@After
fun shutDown() {
if (::amqpServer.isInitialized) {
amqpServer.close()
}
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(serverCert)
}
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (client, clientCert) = createAMQPClient(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
legalName = ALICE_NAME,
crlDistPoints = clientCrlDistPoints
)
if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert)
}
return client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
} }
private fun createAMQPServer(port: Int, private fun createAMQPServer(port: Int,
legalName: CordaX500Name, legalName: CordaX500Name,
crlCheckSoftFail: Boolean, crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL", crlDistPoints: CrlDistPoints,
tlsCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$EMPTY_CRL", distPointCrlSource: CertDistPointCrlSource,
maxMessageSize: Int = MAX_MESSAGE_SIZE, sslHandshakeTimeout: Duration?,
sslHandshakeTimeout: Duration? = null): X509Certificate { remotingThreads: Int?): X509Certificate {
check(!::amqpServer.isInitialized) check(!::amqpServer.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates" val certificatesDirectory = baseDirectory / "certificates"
@ -382,92 +423,103 @@ class CertificateRevocationListNodeTests {
doReturn(signingCertificateStore).whenever(it).signingCertificateStore doReturn(signingCertificateStore).whenever(it).signingCertificateStore
} }
serverConfig.configureWithDevSSLCertificate() serverConfig.configureWithDevSSLCertificate()
val nodeCert = recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, tlsCrlDistPoint) val serverCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val keyStore = serverConfig.p2pSslOptions.keyStore.get() val keyStore = serverConfig.p2pSslOptions.keyStore.get()
val amqpConfig = object : AMQPConfiguration { val amqpConfig = object : AMQPConfiguration {
override val keyStore = keyStore override val keyStore = keyStore
override val trustStore = serverConfig.p2pSslOptions.trustStore.get() override val trustStore = serverConfig.p2pSslOptions.trustStore.get()
override val revocationConfig = crlCheckSoftFail.toRevocationConfig() override val revocationConfig = crlCheckSoftFail.toRevocationConfig()
override val maxMessageSize: Int = maxMessageSize override val maxMessageSize: Int = MAX_MESSAGE_SIZE
override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout override val sslHandshakeTimeout: Duration = sslHandshakeTimeout ?: super.sslHandshakeTimeout
} }
amqpServer = AMQPServer("0.0.0.0", port, amqpConfig, threadPoolName = legalName.organisation) amqpServer = AMQPServer(
return nodeCert "0.0.0.0",
} port,
amqpConfig,
private fun recreateNodeCaAndTlsCertificates(signingCertificateStore: CertificateStoreSupplier, threadPoolName = legalName.organisation,
p2pSslConfiguration: MutualSslConfiguration, distPointCrlSource = distPointCrlSource,
nodeCaCrlDistPoint: String?, remotingThreads = remotingThreads
tlsCrlDistPoint: String?): X509Certificate {
val nodeKeyStore = signingCertificateStore.get()
val (nodeCert, nodeKeys) = nodeKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, nodeKeyStore.entryPassword) }
val newNodeCert = crlServer.replaceNodeCertDistPoint(nodeCert, nodeCaCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, crlServer.intermediateCa.certificate) +
nodeKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.drop(2)
nodeKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
}
nodeKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain, nodeKeyStore.entryPassword)
}
val sslKeyStore = p2pSslConfiguration.keyStore.get()
val (tlsCert, tlsKeys) = sslKeyStore.query { getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, sslKeyStore.entryPassword) }
val newTlsCert = tlsCert.withCrlDistPoint(nodeKeys, tlsCrlDistPoint, crlServer.rootCa.certificate.subjectX500Principal)
val sslCertChain = listOf(newTlsCert, newNodeCert, crlServer.intermediateCa.certificate) +
sslKeyStore.query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.drop(3)
sslKeyStore.update {
internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
}
sslKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain, sslKeyStore.entryPassword)
}
return newNodeCert
}
private fun verifyAMQPConnection(crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String? = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
revokeServerCert: Boolean = false,
revokeClientCert: Boolean = false,
sslHandshakeTimeout: Duration? = null,
expectedConnectStatus: Boolean) {
val serverCert = createAMQPServer(
serverPort,
CHARLIE_NAME,
crlCheckSoftFail = crlCheckSoftFail,
nodeCrlDistPoint = nodeCrlDistPoint,
sslHandshakeTimeout = sslHandshakeTimeout
) )
if (revokeServerCert) { return serverCert
crlServer.revokedNodeCerts.add(serverCert.serialNumber) }
}
class ArtemisServerRevocationTest : AbstractServerRevocationTest() {
private lateinit var artemisNode: ArtemisNode
private var crlCheckArtemisServer = true
@After
fun shutDown() {
if (::artemisNode.isInitialized) {
artemisNode.close()
} }
amqpServer.start() }
amqpServer.onReceive.subscribe {
it.complete(true) @Test(timeout = 300_000)
} fun `connection succeeds with disabled CRL check on revoked node certificate`() {
val clientCert = createAMQPClient( crlCheckArtemisServer = false
verifyConnection(
crlCheckSoftFail = false,
revokeClientCert = true,
expectedConnectedStatus = true
)
}
override fun verifyConnection(crlCheckSoftFail: Boolean,
crlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?,
remotingThreads: Int?,
clientCrlDistPoints: CrlDistPoints,
revokeClientCert: Boolean,
revokeServerCert: Boolean,
expectedConnectedStatus: Boolean): BlockingQueue<ConnectionChange> {
val (client, clientCert) = createAMQPClient(
serverPort, serverPort,
crlCheckSoftFail = crlCheckSoftFail, crlCheckSoftFail = true,
legalName = ALICE_NAME, legalName = ALICE_NAME,
nodeCrlDistPoint = nodeCrlDistPoint crlDistPoints = clientCrlDistPoints
) )
if (revokeClientCert) { if (revokeClientCert) {
crlServer.revokedNodeCerts.add(clientCert.serialNumber) crlServer.revokedNodeCerts.add(clientCert)
} }
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.start() val nodeCert = startArtemisNode(
val serverConnect = serverConnected.get() CHARLIE_NAME,
assertThat(serverConnect.connected).isEqualTo(expectedConnectStatus) crlCheckSoftFail,
defaultCrlDistPoints,
crlSource,
sslHandshakeTimeout,
remotingThreads
)
if (revokeServerCert) {
crlServer.revokedNodeCerts.add(nodeCert)
}
val queueName = "${P2P_PREFIX}Test"
artemisNode.client.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val clientConnectionChangeStatus = client.waitForInitialConnectionAndCaptureChanges(expectedConnectedStatus)
if (expectedConnectedStatus) {
val msg = client.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
client.write(msg)
assertThat(msg.onComplete.get()).isEqualTo(MessageStatus.Acknowledged)
}
return clientConnectionChangeStatus
} }
private fun createArtemisServerAndClient(legalName: CordaX500Name, private fun startArtemisNode(legalName: CordaX500Name,
crlCheckSoftFail: Boolean, crlCheckSoftFail: Boolean,
crlCheckArtemisServer: Boolean, crlDistPoints: CrlDistPoints,
nodeCrlDistPoint: String, distPointCrlSource: CertDistPointCrlSource,
sslHandshakeTimeout: Duration?): Pair<ArtemisMessagingServer, ArtemisMessagingClient> { sslHandshakeTimeout: Duration?,
val baseDirectory = temporaryFolder.root.toPath() / "artemis" remotingThreads: Int?): X509Certificate {
check(!::artemisNode.isInitialized)
val baseDirectory = temporaryFolder.root.toPath() / legalName.organisation
val certificatesDirectory = baseDirectory / "certificates" val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory) val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout) val p2pSslConfiguration = CertificateStoreStubs.P2P.withCertificatesDirectory(certificatesDirectory, sslHandshakeTimeout = sslHandshakeTimeout)
@ -483,62 +535,34 @@ class CertificateRevocationListNodeTests {
doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer doReturn(crlCheckArtemisServer).whenever(it).crlCheckArtemisServer
} }
artemisConfig.configureWithDevSSLCertificate() artemisConfig.configureWithDevSSLCertificate()
recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, nodeCrlDistPoint, null) val nodeCert = crlDistPoints.recreateNodeCaAndTlsCertificates(signingCertificateStore, p2pSslConfiguration, crlServer)
val server = ArtemisMessagingServer( val server = ArtemisMessagingServer(
artemisConfig, artemisConfig,
artemisConfig.p2pAddress, artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-server", threadPoolName = "${legalName.organisation}-server",
trace = true trace = true,
distPointCrlSource = distPointCrlSource,
remotingThreads = remotingThreads
) )
val client = ArtemisMessagingClient( val client = ArtemisMessagingClient(
artemisConfig.p2pSslOptions, artemisConfig.p2pSslOptions,
artemisConfig.p2pAddress, artemisConfig.p2pAddress,
MAX_MESSAGE_SIZE, MAX_MESSAGE_SIZE,
threadPoolName = "${legalName.organisation}-client", threadPoolName = "${legalName.organisation}-client"
trace = true
) )
server.start() server.start()
client.start() client.start()
return server to client val artemisNode = ArtemisNode(server, client)
this.artemisNode = artemisNode
return nodeCert
} }
private fun verifyArtemisConnection(crlCheckSoftFail: Boolean, private class ArtemisNode(val server: ArtemisMessagingServer, val client: ArtemisMessagingClient) : Closeable {
crlCheckArtemisServer: Boolean, override fun close() {
expectedConnected: Boolean = true, client.stop()
expectedStatus: MessageStatus? = null, server.close()
revokedNodeCert: Boolean = false,
nodeCrlDistPoint: String = "http://${crlServer.hostAndPort}/crl/$NODE_CRL",
sslHandshakeTimeout: Duration? = null) {
val queueName = P2P_PREFIX + "Test"
val (artemisServer, artemisClient) = createArtemisServerAndClient(
CHARLIE_NAME,
crlCheckSoftFail,
crlCheckArtemisServer,
nodeCrlDistPoint,
sslHandshakeTimeout
)
artemisServer.use {
artemisClient.started!!.session.createQueue(
QueueConfiguration(queueName).setRoutingType(RoutingType.ANYCAST).setAddress(queueName).setDurable(true)
)
val nodeCert = createAMQPClient(serverPort, true, ALICE_NAME, nodeCrlDistPoint)
if (revokedNodeCert) {
crlServer.revokedNodeCerts.add(nodeCert.serialNumber)
}
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val clientConnect = clientConnected.get()
assertThat(clientConnect.connected).isEqualTo(expectedConnected)
if (expectedConnected) {
val msg = amqpClient.createMessage("Test".toByteArray(), queueName, CHARLIE_NAME.toString(), emptyMap())
amqpClient.write(msg)
assertEquals(expectedStatus, msg.onComplete.get())
}
artemisClient.stop()
} }
} }
} }

View File

@ -4,12 +4,15 @@ import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import io.netty.channel.EventLoopGroup import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.nio.NioEventLoopGroup
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.ArtemisMessagingServer
@ -23,7 +26,9 @@ import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.init import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
@ -31,9 +36,6 @@ import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.MAX_MESSAGE_SIZE import net.corda.testing.core.MAX_MESSAGE_SIZE
import net.corda.testing.driver.internal.incrementalPortAllocation import net.corda.testing.driver.internal.incrementalPortAllocation
import net.corda.testing.internal.createDevIntermediateCaCertPath import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.corda.coretesting.internal.rigorousMock
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.protonwrapper.netty.toRevocationConfig
import org.apache.activemq.artemis.api.core.QueueConfiguration import org.apache.activemq.artemis.api.core.QueueConfiguration
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
@ -44,7 +46,11 @@ import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import javax.net.ssl.* import javax.net.ssl.SSLContext
import javax.net.ssl.SSLHandshakeException
import javax.net.ssl.SSLParameters
import javax.net.ssl.SSLServerSocket
import javax.net.ssl.SSLSocket
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -146,15 +152,10 @@ class ProtonWrapperTests {
sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) } sslConfig.keyStore.get(true).also { it.registerDevP2pCertificates(ALICE_NAME, rootCa.certificate, intermediateCa) }
sslConfig.createTrustStore(rootCa.certificate) sslConfig.createTrustStore(rootCa.certificate)
val keyStore = sslConfig.keyStore.get()
val trustStore = sslConfig.trustStore.get()
val context = SSLContext.getInstance("TLS") val context = SSLContext.getInstance("TLS")
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) val keyManagerFactory = keyManagerFactory(sslConfig.keyStore.get())
keyManagerFactory.init(keyStore)
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustMgrFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) val trustMgrFactory = trustManagerFactory(sslConfig.trustStore.get())
trustMgrFactory.init(trustStore)
val trustManagers = trustMgrFactory.trustManagers val trustManagers = trustMgrFactory.trustManagers
context.init(keyManagers, trustManagers, newSecureRandom()) context.init(keyManagers, trustManagers, newSecureRandom())
@ -442,7 +443,7 @@ class ProtonWrapperTests {
amqpServer.use { amqpServer.use {
val connectionEvents = amqpServer.onConnection.toBlocking().iterator val connectionEvents = amqpServer.onConnection.toBlocking().iterator
amqpServer.start() amqpServer.start()
val sharedThreads = NioEventLoopGroup() val sharedThreads = NioEventLoopGroup(DefaultThreadFactory("sharedThreads"))
val amqpClient1 = createSharedThreadsClient(sharedThreads, 0) val amqpClient1 = createSharedThreadsClient(sharedThreads, 0)
val amqpClient2 = createSharedThreadsClient(sharedThreads, 1) val amqpClient2 = createSharedThreadsClient(sharedThreads, 1)
amqpClient1.start() amqpClient1.start()
@ -608,7 +609,7 @@ class ProtonWrapperTests {
listOf(NetworkHostAndPort("localhost", serverPort)), listOf(NetworkHostAndPort("localhost", serverPort)),
setOf(ALICE_NAME), setOf(ALICE_NAME),
amqpConfig, amqpConfig,
sharedThreadPool = sharedEventGroup) nettyThreading = AMQPClient.NettyThreading.Shared(sharedEventGroup))
} }
private fun createServer(port: Int, private fun createServer(port: Int,

View File

@ -3,6 +3,8 @@ package net.corda.services.messaging
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.internal.config.MutualSslConfiguration import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactory
import org.apache.qpid.jms.JmsConnectionFactory import org.apache.qpid.jms.JmsConnectionFactory
import org.apache.qpid.jms.meta.JmsConnectionInfo import org.apache.qpid.jms.meta.JmsConnectionInfo
import org.apache.qpid.jms.provider.Provider import org.apache.qpid.jms.provider.Provider
@ -24,9 +26,7 @@ import javax.jms.Connection
import javax.jms.Message import javax.jms.Message
import javax.jms.MessageProducer import javax.jms.MessageProducer
import javax.jms.Session import javax.jms.Session
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
/** /**
* Simple AMQP client connecting to broker using JMS. * Simple AMQP client connecting to broker using JMS.
@ -59,12 +59,8 @@ class SimpleAMQPClient(private val target: NetworkHostAndPort, private val confi
private lateinit var connection: Connection private lateinit var connection: Connection
private fun sslContext(): SSLContext { private fun sslContext(): SSLContext {
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply { val keyManagerFactory = keyManagerFactory(config.keyStore.get())
init(config.keyStore.get().value.internal, config.keyStore.entryPassword.toCharArray()) val trustManagerFactory = trustManagerFactory(config.trustStore.get())
}
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(config.trustStore.get().value.internal)
}
val sslContext = SSLContext.getInstance("TLS") val sslContext = SSLContext.getInstance("TLS")
val keyManagers = keyManagerFactory.keyManagers val keyManagers = keyManagerFactory.keyManagers
val trustManagers = trustManagerFactory.trustManagers val trustManagers = trustManagerFactory.trustManagers

View File

@ -5,8 +5,8 @@ import com.codahale.metrics.Gauge
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.collect.MutableClassToInstanceMap import com.google.common.collect.MutableClassToInstanceMap
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.zaxxer.hikari.pool.HikariPool import com.zaxxer.hikari.pool.HikariPool
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.common.logging.errorReporting.NodeDatabaseErrors import net.corda.common.logging.errorReporting.NodeDatabaseErrors
import net.corda.confidential.SwapIdentitiesFlow import net.corda.confidential.SwapIdentitiesFlow
import net.corda.core.CordaException import net.corda.core.CordaException
@ -73,6 +73,7 @@ import net.corda.core.toFuture
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.djvm.source.ApiSource import net.corda.djvm.source.ApiSource
import net.corda.djvm.source.EmptyApi import net.corda.djvm.source.EmptyApi
@ -172,6 +173,7 @@ import net.corda.nodeapi.internal.persistence.RestrictedEntityManager
import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.persistence.contextDatabase import net.corda.nodeapi.internal.persistence.contextDatabase
import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess import net.corda.nodeapi.internal.persistence.withoutDatabaseAccess
import net.corda.nodeapi.internal.namedThreadPoolExecutor
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.jolokia.jvmagent.JolokiaServer import org.jolokia.jvmagent.JolokiaServer
import org.jolokia.jvmagent.JolokiaServerConfig import org.jolokia.jvmagent.JolokiaServerConfig
@ -187,9 +189,6 @@ import java.util.ArrayList
import java.util.Properties import java.util.Properties
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.ThreadPoolExecutor
import java.util.concurrent.TimeUnit
import java.util.concurrent.TimeUnit.MINUTES import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS import java.util.concurrent.TimeUnit.SECONDS
import java.util.function.Consumer import java.util.function.Consumer
@ -353,7 +352,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>() private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
private val cordappTelemetryComponents = MutableClassToInstanceMap.create<TelemetryComponent>() private val cordappTelemetryComponents = MutableClassToInstanceMap.create<TelemetryComponent>()
private val shutdownExecutor = Executors.newSingleThreadExecutor() private val shutdownExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("Shutdown"))
protected abstract val transactionVerifierWorkerCount: Int protected abstract val transactionVerifierWorkerCount: Int
/** /**
@ -808,7 +807,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
} else { } else {
1.days 1.days
} }
val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("Network Map Updater")) val executor = Executors.newSingleThreadScheduledExecutor(NamedThreadFactory("NetworkMapPublisher"))
executor.submit(object : Runnable { executor.submit(object : Runnable {
override fun run() { override fun run() {
val republishInterval = try { val republishInterval = try {
@ -917,13 +916,12 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
} }
// Start with 1 thread and scale up to the configured thread pool size if needed // Start with 1 thread and scale up to the configured thread pool size if needed
// Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool] // Parameters of [ThreadPoolExecutor] based on [Executors.newFixedThreadPool]
return ThreadPoolExecutor( return namedThreadPoolExecutor(
1, corePoolSize = 1,
numberOfThreads, maxPoolSize = numberOfThreads,
0L, idleKeepAlive = 0.millis,
TimeUnit.MILLISECONDS, poolName = "flow-external-operation-thread",
LinkedBlockingQueue<Runnable>(), daemonThreads = true
ThreadFactoryBuilder().setNameFormat("flow-external-operation-thread").setDaemon(true).build()
) )
} }
@ -1174,7 +1172,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
networkParameters: NetworkParameters) networkParameters: NetworkParameters)
protected open fun makeVaultService(keyManagementService: KeyManagementService, protected open fun makeVaultService(keyManagementService: KeyManagementService,
services: ServicesForResolution, services: NodeServicesForResolution,
database: CordaPersistence, database: CordaPersistence,
cordappLoader: CordappLoader): VaultServiceInternal { cordappLoader: CordappLoader): VaultServiceInternal {
return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader) return NodeVaultService(platformClock, keyManagementService, services, database, schemaService, cordappLoader.appClassLoader)

View File

@ -415,12 +415,13 @@ open class Node(configuration: NodeConfiguration,
} }
private fun makeBridgeControlListener(serverAddress: NetworkHostAndPort, networkParameters: NetworkParameters): BridgeControlListener { private fun makeBridgeControlListener(serverAddress: NetworkHostAndPort, networkParameters: NetworkParameters): BridgeControlListener {
val artemisMessagingClientFactory = { val artemisMessagingClientFactory = { threadPoolName: String ->
ArtemisMessagingClient( ArtemisMessagingClient(
configuration.p2pSslOptions, configuration.p2pSslOptions,
serverAddress, serverAddress,
networkParameters.maxMessageSize, networkParameters.maxMessageSize,
failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) } failoverCallback = { errorAndTerminate("ArtemisMessagingClient failed. Shutting down.", null) },
threadPoolName = threadPoolName
) )
} }
return BridgeControlListener( return BridgeControlListener(
@ -431,7 +432,8 @@ open class Node(configuration: NodeConfiguration,
networkParameters.maxMessageSize, networkParameters.maxMessageSize,
configuration.crlCheckSoftFail.toRevocationConfig(), configuration.crlCheckSoftFail.toRevocationConfig(),
false, false,
artemisMessagingClientFactory) artemisMessagingClientFactory
)
} }
private fun startLocalRpcBroker(securityManager: RPCSecurityManager): BrokerAddresses? { private fun startLocalRpcBroker(securityManager: RPCSecurityManager): BrokerAddresses? {

View File

@ -0,0 +1,15 @@
package net.corda.node.internal
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.node.ServicesForResolution
import java.util.LinkedHashSet
interface NodeServicesForResolution : ServicesForResolution {
@Throws(TransactionResolutionException::class)
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> = loadStates(stateRefs, LinkedHashSet())
fun <T : ContractState, C : MutableCollection<StateAndRef<T>>> loadStates(input: Iterable<StateRef>, output: C): C
}

View File

@ -1,16 +1,26 @@
package net.corda.node.internal package net.corda.node.internal
import net.corda.core.contracts.* import net.corda.core.contracts.Attachment
import net.corda.core.contracts.AttachmentResolutionException
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionResolutionException
import net.corda.core.contracts.TransactionState
import net.corda.core.cordapp.CordappProvider import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.SerializedStateAndRef import net.corda.core.internal.SerializedStateAndRef
import net.corda.core.internal.uncheckedCast
import net.corda.core.node.NetworkParameters import net.corda.core.node.NetworkParameters
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.AttachmentStorage
import net.corda.core.node.services.IdentityService import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService import net.corda.core.node.services.NetworkParametersService
import net.corda.core.node.services.TransactionStorage import net.corda.core.node.services.TransactionStorage
import net.corda.core.transactions.BaseTransaction
import net.corda.core.transactions.ContractUpgradeWireTransaction import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent import net.corda.core.transactions.WireTransaction.Companion.resolveStateRefBinaryComponent
@ -20,31 +30,28 @@ data class ServicesForResolutionImpl(
override val cordappProvider: CordappProvider, override val cordappProvider: CordappProvider,
override val networkParametersService: NetworkParametersService, override val networkParametersService: NetworkParametersService,
private val validatedTransactions: TransactionStorage private val validatedTransactions: TransactionStorage
) : ServicesForResolution { ) : NodeServicesForResolution {
override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?: override val networkParameters: NetworkParameters get() = networkParametersService.lookup(networkParametersService.currentHash) ?:
throw IllegalArgumentException("No current parameters in network parameters storage") throw IllegalArgumentException("No current parameters in network parameters storage")
@Throws(TransactionResolutionException::class) @Throws(TransactionResolutionException::class)
override fun loadState(stateRef: StateRef): TransactionState<*> { override fun loadState(stateRef: StateRef): TransactionState<*> {
val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) return toBaseTransaction(stateRef.txhash).outputs[stateRef.index]
return stx.resolveBaseTransaction(this).outputs[stateRef.index]
} }
@Throws(TransactionResolutionException::class) override fun <T : ContractState, C : MutableCollection<StateAndRef<T>>> loadStates(input: Iterable<StateRef>, output: C): C {
override fun loadStates(stateRefs: Set<StateRef>): Set<StateAndRef<ContractState>> { val baseTxs = HashMap<SecureHash, BaseTransaction>()
return stateRefs.groupBy { it.txhash }.flatMap { return input.mapTo(output) { stateRef ->
val stx = validatedTransactions.getTransaction(it.key) ?: throw TransactionResolutionException(it.key) val baseTx = baseTxs.computeIfAbsent(stateRef.txhash, ::toBaseTransaction)
val baseTx = stx.resolveBaseTransaction(this) StateAndRef(uncheckedCast(baseTx.outputs[stateRef.index]), stateRef)
it.value.map { ref -> StateAndRef(baseTx.outputs[ref.index], ref) } }
}.toSet()
} }
@Throws(TransactionResolutionException::class, AttachmentResolutionException::class) @Throws(TransactionResolutionException::class, AttachmentResolutionException::class)
override fun loadContractAttachment(stateRef: StateRef): Attachment { override fun loadContractAttachment(stateRef: StateRef): Attachment {
// We may need to recursively chase transactions if there are notary changes. // We may need to recursively chase transactions if there are notary changes.
fun inner(stateRef: StateRef, forContractClassName: String?): Attachment { fun inner(stateRef: StateRef, forContractClassName: String?): Attachment {
val ctx = validatedTransactions.getTransaction(stateRef.txhash)?.coreTransaction val ctx = getSignedTransaction(stateRef.txhash).coreTransaction
?: throw TransactionResolutionException(stateRef.txhash)
when (ctx) { when (ctx) {
is WireTransaction -> { is WireTransaction -> {
val transactionState = ctx.outRef<ContractState>(stateRef.index).state val transactionState = ctx.outRef<ContractState>(stateRef.index).state
@ -69,4 +76,10 @@ data class ServicesForResolutionImpl(
} }
return inner(stateRef, null) return inner(stateRef, null)
} }
private fun toBaseTransaction(txhash: SecureHash): BaseTransaction = getSignedTransaction(txhash).resolveBaseTransaction(this)
private fun getSignedTransaction(txhash: SecureHash): SignedTransaction {
return validatedTransactions.getTransaction(txhash) ?: throw TransactionResolutionException(txhash)
}
} }

View File

@ -135,12 +135,12 @@ class BrokerJaasLoginModule : BaseBrokerJaasLoginModule() {
Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE))) Pair(ArtemisMessagingComponent.NODE_RPC_USER, listOf(RolePrincipal(NODE_RPC_ROLE)))
} }
ArtemisMessagingComponent.PEER_USER -> { ArtemisMessagingComponent.PEER_USER -> {
requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." } val p2pJaasConfig = requireNotNull(p2pJaasConfig) { "Attempted to connect as a peer to the rpc broker." }
requireTls(certificates) requireTls(certificates)
// This check is redundant as it was performed already during the SSL handshake // This check is redundant as it was performed already during the SSL handshake
CertificateChainCheckPolicy.RootMustMatch.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) CertificateChainCheckPolicy.RootMustMatch
CertificateChainCheckPolicy.RevocationCheck(p2pJaasConfig!!.revocationMode) .createCheck(p2pJaasConfig.keyStore, p2pJaasConfig.trustStore)
.createCheck(p2pJaasConfig!!.keyStore, p2pJaasConfig!!.trustStore).checkCertificateChain(certificates) .checkCertificateChain(certificates)
Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE))) Pair(certificates.first().subjectDN.name, listOf(RolePrincipal(PEER_ROLE)))
} }
else -> { else -> {

View File

@ -2,17 +2,9 @@ package net.corda.node.internal.artemis
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.crypto.X509CertificateFactory
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.certPathToString
import java.security.KeyStore import java.security.KeyStore
import java.security.cert.CertPathValidator
import java.security.cert.CertPathValidatorException
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.PKIXBuilderParameters
import java.security.cert.X509CertSelector
sealed class CertificateChainCheckPolicy { sealed class CertificateChainCheckPolicy {
companion object { companion object {
@ -92,33 +84,4 @@ sealed class CertificateChainCheckPolicy {
} }
} }
} }
class RevocationCheck(val revocationConfig: RevocationConfig) : CertificateChainCheckPolicy() {
constructor(revocationMode: RevocationConfig.Mode) : this(RevocationConfigImpl(revocationMode))
override fun createCheck(keyStore: KeyStore, trustStore: KeyStore): Check {
return object : Check {
override fun checkCertificateChain(theirChain: Array<java.security.cert.X509Certificate>) {
// Convert javax.security.cert.X509Certificate to java.security.cert.X509Certificate.
val chain = theirChain.map { X509CertificateFactory().generateCertificate(it.encoded.inputStream()) }
log.info("Check Client Certpath:\r\n${certPathToString(chain.toTypedArray())}")
// Drop the last certificate which must be a trusted root (validated by RootMustMatch).
// Assume that there is no more trusted roots (or corresponding public keys) in the remaining chain.
// See PKIXValidator.engineValidate() for reference implementation.
val certPath = X509Utilities.buildCertPath(chain.dropLast(1))
val certPathValidator = CertPathValidator.getInstance("PKIX")
val pkixRevocationChecker = revocationConfig.createPKIXRevocationChecker()
val params = PKIXBuilderParameters(trustStore, X509CertSelector())
params.addCertPathChecker(pkixRevocationChecker)
try {
certPathValidator.validate(certPath, params)
} catch (ex: CertPathValidatorException) {
log.error("Bad certificate path", ex)
throw ex
}
}
}
}
}
} }

View File

@ -2,7 +2,6 @@ package net.corda.node.migration
import liquibase.database.Database import liquibase.database.Database
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
@ -19,6 +18,7 @@ import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.SchemaMigration
@ -62,8 +62,7 @@ class VaultStateMigration : CordaMigration() {
private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef<ContractState> { private fun getStateAndRef(persistentState: VaultSchemaV1.VaultStates): StateAndRef<ContractState> {
val persistentStateRef = persistentState.stateRef ?: val persistentStateRef = persistentState.stateRef ?:
throw VaultStateMigrationException("Persistent state ref missing from state") throw VaultStateMigrationException("Persistent state ref missing from state")
val txHash = SecureHash.create(persistentStateRef.txId) val stateRef = persistentStateRef.toStateRef()
val stateRef = StateRef(txHash, persistentStateRef.index)
val state = try { val state = try {
servicesForResolution.loadState(stateRef) servicesForResolution.loadState(stateRef)
} catch (e: Exception) { } catch (e: Exception) {

View File

@ -2,6 +2,7 @@ package net.corda.node.services.events
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.context.InvocationOrigin import net.corda.core.context.InvocationOrigin
@ -148,7 +149,7 @@ class NodeSchedulerService(private val clock: CordaClock,
// from the database // from the database
private val startingStateRefs: MutableSet<ScheduledStateRef> = ConcurrentHashMap.newKeySet<ScheduledStateRef>() private val startingStateRefs: MutableSet<ScheduledStateRef> = ConcurrentHashMap.newKeySet<ScheduledStateRef>()
private val mutex = ThreadBox(InnerState()) private val mutex = ThreadBox(InnerState())
private val schedulerTimerExecutor = Executors.newSingleThreadExecutor() private val schedulerTimerExecutor = Executors.newSingleThreadExecutor(DefaultThreadFactory("SchedulerService"))
// if there's nothing to do, check every minute if something fell through the cracks. // if there's nothing to do, check every minute if something fell through the cracks.
// any new state should trigger a reschedule immediately if nothing is scheduled, so I would not expect // any new state should trigger a reschedule immediately if nothing is scheduled, so I would not expect

View File

@ -2,8 +2,8 @@ package net.corda.node.services.events
import net.corda.core.contracts.ScheduledStateRef import net.corda.core.contracts.ScheduledStateRef
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.SecureHash
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
interface ScheduledFlowRepository { interface ScheduledFlowRepository {
@ -25,9 +25,8 @@ class PersistentScheduledFlowRepository(val database: CordaPersistence) : Schedu
} }
private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair<StateRef, ScheduledStateRef> { private fun fromPersistentEntity(scheduledStateRecord: NodeSchedulerService.PersistentScheduledState): Pair<StateRef, ScheduledStateRef> {
val txId = scheduledStateRecord.output.txId val stateRef = scheduledStateRecord.output.toStateRef()
val index = scheduledStateRecord.output.index return Pair(stateRef, ScheduledStateRef(stateRef, scheduledStateRecord.scheduledAt))
return Pair(StateRef(SecureHash.create(txId), index), ScheduledStateRef(StateRef(SecureHash.create(txId), index), scheduledStateRecord.scheduledAt))
} }
override fun delete(key: StateRef): Boolean { override fun delete(key: StateRef): Boolean {

View File

@ -7,9 +7,16 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.node.internal.artemis.* import net.corda.node.internal.artemis.ArtemisBroker
import net.corda.node.internal.artemis.BrokerAddresses
import net.corda.node.internal.artemis.BrokerJaasLoginModule
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.NODE_P2P_ROLE
import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE import net.corda.node.internal.artemis.BrokerJaasLoginModule.Companion.PEER_ROLE
import net.corda.node.internal.artemis.NodeJaasConfig
import net.corda.node.internal.artemis.P2PJaasConfig
import net.corda.node.internal.artemis.SecureArtemisConfiguration
import net.corda.node.internal.artemis.UserValidationPlugin
import net.corda.node.internal.artemis.isBindingError
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.utilities.artemis.startSynchronously import net.corda.node.utilities.artemis.startSynchronously
import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor import net.corda.nodeapi.internal.AmqpMessageSizeChecksInterceptor
@ -21,7 +28,10 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.SECURITY_INVALIDATION_INTERVAL
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pAcceptorTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import net.corda.nodeapi.internal.protonwrapper.netty.trustManagerFactoryWithRevocation
import net.corda.nodeapi.internal.requireOnDefaultFileSystem import net.corda.nodeapi.internal.requireOnDefaultFileSystem
import net.corda.nodeapi.internal.revocation.CertDistPointCrlSource
import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration import org.apache.activemq.artemis.api.config.ActiveMQDefaultConfiguration
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
@ -32,9 +42,7 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import java.io.IOException
import java.lang.Long.max import java.lang.Long.max
import java.security.KeyStoreException
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.login.AppConfigurationEntry import javax.security.auth.login.AppConfigurationEntry
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED
@ -57,8 +65,10 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
private val messagingServerAddress: NetworkHostAndPort, private val messagingServerAddress: NetworkHostAndPort,
private val maxMessageSize: Int, private val maxMessageSize: Int,
private val journalBufferTimeout : Int? = null, private val journalBufferTimeout : Int? = null,
private val threadPoolName: String = "ArtemisServer", private val threadPoolName: String = "P2PServer",
private val trace: Boolean = false) : ArtemisBroker, SingletonSerializeAsToken() { private val trace: Boolean = false,
private val distPointCrlSource: CertDistPointCrlSource = CertDistPointCrlSource.SINGLETON,
private val remotingThreads: Int? = null) : ArtemisBroker, SingletonSerializeAsToken() {
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
} }
@ -92,7 +102,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
override val started: Boolean override val started: Boolean
get() = activeMQServer.isStarted get() = activeMQServer.isStarted
@Throws(IOException::class, AddressBindingException::class, KeyStoreException::class) @Suppress("ThrowsCount")
private fun configureAndStartServer() { private fun configureAndStartServer() {
val artemisConfig = createArtemisConfig() val artemisConfig = createArtemisConfig()
val securityManager = createArtemisSecurityManager() val securityManager = createArtemisSecurityManager()
@ -132,11 +142,23 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
// The transaction cache is configurable, and drives other cache sizes. // The transaction cache is configurable, and drives other cache sizes.
globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize) globalMaxSize = max(config.transactionCacheSizeBytes, 10L * maxMessageSize)
val revocationMode = if (config.crlCheckArtemisServer) {
if (config.crlCheckSoftFail) RevocationConfig.Mode.SOFT_FAIL else RevocationConfig.Mode.HARD_FAIL
} else {
RevocationConfig.Mode.OFF
}
val trustManagerFactory = trustManagerFactoryWithRevocation(
config.p2pSslOptions.trustStore.get(),
RevocationConfigImpl(revocationMode),
distPointCrlSource
)
addAcceptorConfiguration(p2pAcceptorTcpTransport( addAcceptorConfiguration(p2pAcceptorTcpTransport(
NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port), NetworkHostAndPort(messagingServerAddress.host, messagingServerAddress.port),
config.p2pSslOptions, config.p2pSslOptions,
trustManagerFactory,
threadPoolName = threadPoolName, threadPoolName = threadPoolName,
trace = trace trace = trace,
remotingThreads = remotingThreads
)) ))
// Enable built in message deduplication. Note we still have to do our own as the delayed commits // Enable built in message deduplication. Note we still have to do our own as the delayed commits
// and our own definition of commit mean that the built in deduplication cannot remove all duplicates. // and our own definition of commit mean that the built in deduplication cannot remove all duplicates.

View File

@ -10,6 +10,8 @@ import io.netty.handler.ssl.SslHandshakeTimeoutException
import net.corda.core.internal.declaredField import net.corda.core.internal.declaredField
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisTcpTransport import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.protonwrapper.netty.sslDelegatedTaskExecutor
import net.corda.nodeapi.internal.setThreadPoolName
import org.apache.activemq.artemis.api.core.BaseInterceptor import org.apache.activemq.artemis.api.core.BaseInterceptor
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptor
import org.apache.activemq.artemis.core.server.balancing.RedirectHandler import org.apache.activemq.artemis.core.server.balancing.RedirectHandler
@ -19,14 +21,18 @@ import org.apache.activemq.artemis.spi.core.remoting.Acceptor
import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory import org.apache.activemq.artemis.spi.core.remoting.AcceptorFactory
import org.apache.activemq.artemis.spi.core.remoting.BufferHandler import org.apache.activemq.artemis.spi.core.remoting.BufferHandler
import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener import org.apache.activemq.artemis.spi.core.remoting.ServerConnectionLifeCycleListener
import org.apache.activemq.artemis.spi.core.remoting.ssl.OpenSSLContextFactoryProvider
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextFactoryProvider
import org.apache.activemq.artemis.utils.ConfigurationHelper import org.apache.activemq.artemis.utils.ConfigurationHelper
import org.apache.activemq.artemis.utils.actors.OrderedExecutor import org.apache.activemq.artemis.utils.actors.OrderedExecutor
import java.net.SocketAddress
import java.nio.channels.ClosedChannelException import java.nio.channels.ClosedChannelException
import java.time.Duration import java.time.Duration
import java.util.concurrent.Executor import java.util.concurrent.Executor
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.regex.Pattern import java.util.regex.Pattern
import javax.net.ssl.SSLEngine import javax.net.ssl.SSLEngine
import javax.net.ssl.SSLPeerUnverifiedException
@Suppress("unused") // Used via reflection in ArtemisTcpTransport @Suppress("unused") // Used via reflection in ArtemisTcpTransport
class NodeNettyAcceptorFactory : AcceptorFactory { class NodeNettyAcceptorFactory : AcceptorFactory {
@ -36,10 +42,23 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
handler: BufferHandler?, handler: BufferHandler?,
listener: ServerConnectionLifeCycleListener?, listener: ServerConnectionLifeCycleListener?,
threadPool: Executor, threadPool: Executor,
scheduledThreadPool: ScheduledExecutorService?, scheduledThreadPool: ScheduledExecutorService,
protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?): Acceptor { protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?): Acceptor {
val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "Acceptor", configuration)
threadPool.setThreadPoolName("$threadPoolName-artemis")
scheduledThreadPool.setThreadPoolName("$threadPoolName-artemis-scheduler")
val failureExecutor = OrderedExecutor(threadPool) val failureExecutor = OrderedExecutor(threadPool)
return NodeNettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) return NodeNettyAcceptor(
name,
clusterConnection,
configuration,
handler,
listener,
scheduledThreadPool,
failureExecutor,
protocolMap,
"$threadPoolName-netty"
)
} }
@ -50,14 +69,21 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
listener: ServerConnectionLifeCycleListener?, listener: ServerConnectionLifeCycleListener?,
scheduledThreadPool: ScheduledExecutorService?, scheduledThreadPool: ScheduledExecutorService?,
failureExecutor: Executor, failureExecutor: Executor,
protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?) : protocolMap: MutableMap<String, ProtocolManager<BaseInterceptor<*>, RedirectHandler<*>>>?,
private val threadPoolName: String) :
NettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap) NettyAcceptor(name, clusterConnection, configuration, handler, listener, scheduledThreadPool, failureExecutor, protocolMap)
{ {
companion object { companion object {
private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""") private val defaultThreadPoolNamePattern = Pattern.compile("""Thread-(\d+) \(activemq-netty-threads\)""")
init {
// Make sure Artemis isn't using another (Open)SSLContextFactory
check(SSLContextFactoryProvider.getSSLContextFactory() is NodeSSLContextFactory)
check(OpenSSLContextFactoryProvider.getOpenSSLContextFactory() is NodeOpenSSLContextFactory)
}
} }
private val threadPoolName = ConfigurationHelper.getStringProperty(ArtemisTcpTransport.THREAD_POOL_NAME_NAME, "NodeNettyAcceptor", configuration) private val sslDelegatedTaskExecutor = sslDelegatedTaskExecutor(threadPoolName)
private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration) private val trace = ConfigurationHelper.getBooleanProperty(ArtemisTcpTransport.TRACE_NAME, false, configuration)
@Synchronized @Synchronized
@ -71,11 +97,17 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
} }
} }
@Synchronized
override fun stop() {
super.stop()
sslDelegatedTaskExecutor.shutdown()
}
@Synchronized @Synchronized
override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler { override fun getSslHandler(alloc: ByteBufAllocator?, peerHost: String?, peerPort: Int): SslHandler {
applyThreadPoolName() applyThreadPoolName()
val engine = super.getSslHandler(alloc, peerHost, peerPort).engine() val engine = super.getSslHandler(alloc, peerHost, peerPort).engine()
val sslHandler = NodeAcceptorSslHandler(engine, trace) val sslHandler = NodeAcceptorSslHandler(engine, sslDelegatedTaskExecutor, trace)
val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration? val handshakeTimeout = configuration[ArtemisTcpTransport.SSL_HANDSHAKE_TIMEOUT_NAME] as Duration?
if (handshakeTimeout != null) { if (handshakeTimeout != null) {
sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis() sslHandler.handshakeTimeoutMillis = handshakeTimeout.toMillis()
@ -95,13 +127,15 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
} }
private class NodeAcceptorSslHandler(engine: SSLEngine, private val trace: Boolean) : SslHandler(engine) { private class NodeAcceptorSslHandler(engine: SSLEngine,
delegatedTaskExecutor: Executor,
private val trace: Boolean) : SslHandler(engine, delegatedTaskExecutor) {
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
} }
override fun handlerAdded(ctx: ChannelHandlerContext) { override fun handlerAdded(ctx: ChannelHandlerContext) {
logHandshake() logHandshake(ctx.channel().remoteAddress())
super.handlerAdded(ctx) super.handlerAdded(ctx)
// Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way. // Unfortunately NettyAcceptor does not let us add extra child handlers, so we have to add our logger this way.
if (trace) { if (trace) {
@ -109,17 +143,22 @@ class NodeNettyAcceptorFactory : AcceptorFactory {
} }
} }
private fun logHandshake() { private fun logHandshake(remoteAddress: SocketAddress) {
val start = System.currentTimeMillis() val start = System.currentTimeMillis()
handshakeFuture().addListener { handshakeFuture().addListener {
val duration = System.currentTimeMillis() - start val duration = System.currentTimeMillis() - start
val peer = try {
engine().session.peerPrincipal
} catch (e: SSLPeerUnverifiedException) {
remoteAddress
}
when { when {
it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with ${engine().session.peerPrincipal}") it.isSuccess -> logger.info("SSL handshake completed in ${duration}ms with $peer")
it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms") it.isCancelled -> logger.warn("SSL handshake cancelled after ${duration}ms with $peer")
else -> when (it.cause()) { else -> when (it.cause()) {
is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms") is ClosedChannelException -> logger.warn("SSL handshake closed early after ${duration}ms with $peer")
is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms") is SslHandshakeTimeoutException -> logger.warn("SSL handshake timed out after ${duration}ms with $peer")
else -> logger.warn("SSL handshake failed after ${duration}ms", it.cause()) else -> logger.warn("SSL handshake failed after ${duration}ms with $peer", it.cause())
} }
} }
} }

View File

@ -0,0 +1,59 @@
package net.corda.node.services.messaging
import io.netty.handler.ssl.SslContext
import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslProvider
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.TRUST_MANAGER_FACTORY_NAME
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.netty.createAndInitSslContext
import net.corda.nodeapi.internal.protonwrapper.netty.keyManagerFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultOpenSSLContextFactory
import org.apache.activemq.artemis.core.remoting.impl.ssl.DefaultSSLContextFactory
import org.apache.activemq.artemis.spi.core.remoting.ssl.SSLContextConfig
import java.nio.file.Paths
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
class NodeSSLContextFactory : DefaultSSLContextFactory() {
override fun getSSLContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SSLContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
createAndInitSslContext(loadKeyManagerFactory(config), trustManagerFactory)
} else {
super.getSSLContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
class NodeOpenSSLContextFactory : DefaultOpenSSLContextFactory() {
override fun getServerSslContext(config: SSLContextConfig, additionalOpts: Map<String, Any>): SslContext {
val trustManagerFactory = additionalOpts[TRUST_MANAGER_FACTORY_NAME] as TrustManagerFactory?
return if (trustManagerFactory != null) {
SslContextBuilder
.forServer(loadKeyManagerFactory(config))
.sslProvider(SslProvider.OPENSSL)
.trustManager(trustManagerFactory)
.build()
} else {
super.getServerSslContext(config, additionalOpts)
}
}
override fun getPriority(): Int {
// We make sure this factory is the one that's chosen, so any sufficiently large value will do.
return 15
}
}
private fun loadKeyManagerFactory(config: SSLContextConfig): KeyManagerFactory {
val keyStore = CertificateStore.fromFile(Paths.get(config.keystorePath), config.keystorePassword, config.keystorePassword, false)
return keyManagerFactory(keyStore)
}

View File

@ -74,7 +74,7 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal,
} }
private val parametersUpdatesTrack = PublishSubject.create<ParametersUpdateInfo>() private val parametersUpdatesTrack = PublishSubject.create<ParametersUpdateInfo>()
private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("Network Map Updater Thread")).apply { private val networkMapPoller = ScheduledThreadPoolExecutor(1, NamedThreadFactory("NetworkMapUpdater")).apply {
executeExistingDelayedTasksAfterShutdownPolicy = false executeExistingDelayedTasksAfterShutdownPolicy = false
} }
private var newNetworkParameters: Pair<ParametersUpdate, SignedNetworkParameters>? = null private var newNetworkParameters: Pair<ParametersUpdate, SignedNetworkParameters>? = null
@ -261,9 +261,12 @@ class NetworkMapUpdater(private val networkMapCache: NetworkMapCacheInternal,
//as HTTP GET is mostly IO bound, use more threads than CPU's //as HTTP GET is mostly IO bound, use more threads than CPU's
//maximum threads to use = 24, as if we did not limit this on large machines it could result in 100's of concurrent requests //maximum threads to use = 24, as if we did not limit this on large machines it could result in 100's of concurrent requests
val threadsToUseForNetworkMapDownload = min(Runtime.getRuntime().availableProcessors() * 4, 24) val threadsToUseForNetworkMapDownload = min(Runtime.getRuntime().availableProcessors() * 4, 24)
val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(threadsToUseForNetworkMapDownload, NamedThreadFactory("NetworkMapUpdaterNodeInfoDownloadThread")) val executorToUseForDownloadingNodeInfos = Executors.newFixedThreadPool(
threadsToUseForNetworkMapDownload,
NamedThreadFactory("NetworkMapUpdaterNodeInfoDownload")
)
//DB insert is single threaded - use a single threaded executor for it. //DB insert is single threaded - use a single threaded executor for it.
val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsertThread")) val executorToUseForInsertionIntoDB = Executors.newSingleThreadExecutor(NamedThreadFactory("NetworkMapUpdateDBInsert"))
val hashesToFetch = (allHashesFromNetworkMap - allNodeHashes) val hashesToFetch = (allHashesFromNetworkMap - allNodeHashes)
val networkMapDownloadStartTime = System.currentTimeMillis() val networkMapDownloadStartTime = System.currentTimeMillis()
if (hashesToFetch.isNotEmpty()) { if (hashesToFetch.isNotEmpty()) {

View File

@ -22,8 +22,7 @@ class InternalRPCMessagingClient(val sslConfig: MutualSslConfiguration, val serv
private var rpcServer: RPCServer? = null private var rpcServer: RPCServer? = null
fun init(rpcOps: List<RPCOps>, securityManager: RPCSecurityManager, cacheFactory: NamedCacheFactory) = synchronized(this) { fun init(rpcOps: List<RPCOps>, securityManager: RPCSecurityManager, cacheFactory: NamedCacheFactory) = synchronized(this) {
val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig, threadPoolName = "RPCClient")
val tcpTransport = ArtemisTcpTransport.rpcInternalClientTcpTransport(serverAddress, sslConfig)
locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply { locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply {
// Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this // Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this
// would be the default and the two lines below can be deleted. // would be the default and the two lines below can be deleted.

View File

@ -30,10 +30,10 @@ internal class RpcBrokerConfiguration(baseDirectory: Path, maxMessageSize: Int,
setDirectories(baseDirectory) setDirectories(baseDirectory)
val acceptorConfigurationsSet = mutableSetOf( val acceptorConfigurationsSet = mutableSetOf(
rpcAcceptorTcpTransport(address, sslOptions, useSsl) rpcAcceptorTcpTransport(address, sslOptions, enableSSL = useSsl, threadPoolName = "RPCServer")
) )
adminAddress?.let { adminAddress?.let {
acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration) acceptorConfigurationsSet += rpcInternalAcceptorTcpTransport(it, nodeConfiguration, threadPoolName = "RPCServerAdmin")
} }
acceptorConfigurations = acceptorConfigurationsSet acceptorConfigurations = acceptorConfigurationsSet

View File

@ -1,5 +1,6 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import io.netty.util.concurrent.DefaultThreadFactory
import net.corda.core.flows.FlowSession import net.corda.core.flows.FlowSession
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
@ -22,10 +23,6 @@ internal class FlowMonitor(
) : LifecycleSupport { ) : LifecycleSupport {
private companion object { private companion object {
private fun defaultScheduler(): ScheduledExecutorService {
return Executors.newSingleThreadScheduledExecutor()
}
private val logger = loggerFor<FlowMonitor>() private val logger = loggerFor<FlowMonitor>()
} }
@ -36,7 +33,7 @@ internal class FlowMonitor(
override fun start() { override fun start() {
synchronized(this) { synchronized(this) {
if (scheduler == null) { if (scheduler == null) {
scheduler = defaultScheduler() scheduler = Executors.newSingleThreadScheduledExecutor(DefaultThreadFactory("FlowMonitor"))
shutdownScheduler = true shutdownScheduler = true
} }
scheduler!!.scheduleAtFixedRate({ logFlowsWaitingForParty() }, 0, monitoringPeriod.toMillis(), TimeUnit.MILLISECONDS) scheduler!!.scheduleAtFixedRate({ logFlowsWaitingForParty() }, 0, monitoringPeriod.toMillis(), TimeUnit.MILLISECONDS)

View File

@ -25,6 +25,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.node.services.vault.toStateRef
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
@ -157,13 +158,7 @@ class PersistentUniquenessProvider(val clock: Clock, val database: CordaPersiste
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = { fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable //TODO null check will become obsolete after making DB/JPA columns not nullable
val txId = it.id.txId Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash))
val index = it.id.index
Pair(
StateRef(txhash = SecureHash.create(txId), index = index),
SecureHash.create(it.consumingTxHash)
)
}, },
toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> toPersistentEntity = { (txHash, index): StateRef, id: SecureHash ->
CommittedState( CommittedState(

View File

@ -3,28 +3,65 @@ package net.corda.node.services.vault
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import net.corda.core.CordaRuntimeException import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.* import net.corda.core.contracts.Amount
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.FungibleState
import net.corda.core.contracts.Issued
import net.corda.core.contracts.OwnableState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.containsAny import net.corda.core.crypto.containsAny
import net.corda.core.flows.HospitalizeFlowException import net.corda.core.flows.HospitalizeFlowException
import net.corda.core.internal.* import net.corda.core.internal.ThreadBox
import net.corda.core.internal.TransactionDeserialisationException
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.StatesToRecord import net.corda.core.node.StatesToRecord
import net.corda.core.node.services.* import net.corda.core.node.services.KeyManagementService
import net.corda.core.node.services.Vault.ConstraintInfo.Companion.constraintInfo import net.corda.core.node.services.StatesNotAvailableException
import net.corda.core.node.services.vault.* import net.corda.core.node.services.Vault
import net.corda.core.node.services.VaultQueryException
import net.corda.core.node.services.VaultService
import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
import net.corda.core.node.services.vault.PageSpecification
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.SortAttribute
import net.corda.core.node.services.vault.builder
import net.corda.core.observable.internal.OnResilientSubscribe import net.corda.core.observable.internal.OnResilientSubscribe
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.* import net.corda.core.transactions.ContractUpgradeWireTransaction
import net.corda.core.utilities.* import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.FullTransaction
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.trace
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.services.api.SchemaService import net.corda.node.services.api.SchemaService
import net.corda.node.services.api.VaultServiceInternal import net.corda.node.services.api.VaultServiceInternal
import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.schema.PersistentStateService
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.nodeapi.internal.persistence.* import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.bufferUntilDatabaseCommit
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import org.hibernate.Session import org.hibernate.Session
import org.hibernate.query.Query
import rx.Observable import rx.Observable
import rx.exceptions.OnErrorNotImplementedException import rx.exceptions.OnErrorNotImplementedException
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -32,9 +69,11 @@ import java.security.PublicKey
import java.sql.SQLException import java.sql.SQLException
import java.time.Clock import java.time.Clock
import java.time.Instant import java.time.Instant
import java.util.* import java.util.Arrays
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArraySet import java.util.concurrent.CopyOnWriteArraySet
import java.util.stream.Stream
import javax.persistence.PersistenceException import javax.persistence.PersistenceException
import javax.persistence.Tuple import javax.persistence.Tuple
import javax.persistence.criteria.CriteriaBuilder import javax.persistence.criteria.CriteriaBuilder
@ -54,9 +93,9 @@ import javax.persistence.criteria.Root
class NodeVaultService( class NodeVaultService(
private val clock: Clock, private val clock: Clock,
private val keyManagementService: KeyManagementService, private val keyManagementService: KeyManagementService,
private val servicesForResolution: ServicesForResolution, private val servicesForResolution: NodeServicesForResolution,
private val database: CordaPersistence, private val database: CordaPersistence,
private val schemaService: SchemaService, schemaService: SchemaService,
private val appClassloader: ClassLoader private val appClassloader: ClassLoader
) : SingletonSerializeAsToken(), VaultServiceInternal { ) : SingletonSerializeAsToken(), VaultServiceInternal {
companion object { companion object {
@ -196,7 +235,7 @@ class NodeVaultService(
if (lockId != null) { if (lockId != null) {
lockId = null lockId = null
lockUpdateTime = clock.instant() lockUpdateTime = clock.instant()
log.trace("Releasing soft lock on consumed state: $stateRef") log.trace { "Releasing soft lock on consumed state: $stateRef" }
} }
session.save(state) session.save(state)
} }
@ -227,7 +266,7 @@ class NodeVaultService(
} }
// we are not inside a flow, we are most likely inside a CordaService; // we are not inside a flow, we are most likely inside a CordaService;
// we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates. // we will expose, by default, subscribing of -non unsubscribing- rx.Observers to rawUpdates.
return _rawUpdatesPublisher.resilientOnError() _rawUpdatesPublisher.resilientOnError()
} }
override val updates: Observable<Vault.Update<ContractState>> override val updates: Observable<Vault.Update<ContractState>>
@ -639,7 +678,23 @@ class NodeVaultService(
@Throws(VaultQueryException::class) @Throws(VaultQueryException::class)
override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> { override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> {
try { try {
return _queryBy(criteria, paging, sorting, contractStateType, false) // We decrement by one if the client requests MAX_VALUE, assuming they can not notice this because they don't have enough memory
// to request MAX_VALUE states at once.
val validPaging = if (paging.pageSize == Integer.MAX_VALUE) {
paging.copy(pageSize = Integer.MAX_VALUE - 1)
} else {
checkVaultQuery(paging.pageSize >= 1) { "Page specification: invalid page size ${paging.pageSize} [minimum is 1]" }
paging
}
if (!validPaging.isDefault) {
checkVaultQuery(validPaging.pageNumber >= DEFAULT_PAGE_NUM) {
"Page specification: invalid page number ${validPaging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]"
}
}
log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $validPaging, sorting: $sorting" }
return database.transaction {
queryBy(criteria, validPaging, sorting, contractStateType)
}
} catch (e: VaultQueryException) { } catch (e: VaultQueryException) {
throw e throw e
} catch (e: Exception) { } catch (e: Exception) {
@ -647,100 +702,90 @@ class NodeVaultService(
} }
} }
@Throws(VaultQueryException::class) private fun <T : ContractState> queryBy(criteria: QueryCriteria,
private fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging_: PageSpecification, sorting: Sort, contractStateType: Class<out T>, skipPagingChecks: Boolean): Vault.Page<T> { paging: PageSpecification,
// We decrement by one if the client requests MAX_PAGE_SIZE, assuming they can not notice this because they don't have enough memory sorting: Sort,
// to request `MAX_PAGE_SIZE` states at once. contractStateType: Class<out T>): Vault.Page<T> {
val paging = if (paging_.pageSize == Integer.MAX_VALUE) { // calculate total results where a page specification has been defined
paging_.copy(pageSize = Integer.MAX_VALUE - 1) val totalStatesAvailable = if (paging.isDefault) -1 else queryTotalStateCount(criteria, contractStateType)
} else {
paging_ val (query, stateTypes) = createQuery(criteria, contractStateType, sorting)
query.setResultWindow(paging)
val statesMetadata: MutableList<Vault.StateMetadata> = mutableListOf()
val otherResults: MutableList<Any> = mutableListOf()
query.resultStream(paging).use { results ->
results.forEach { result ->
val result0 = result[0]
if (result0 is VaultSchemaV1.VaultStates) {
statesMetadata.add(result0.toStateMetadata())
} else {
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }
otherResults.addAll(result.toArray().asList())
}
}
} }
log.debug { "Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" }
return database.transaction {
// calculate total results where a page specification has been defined
var totalStates = -1L
if (!skipPagingChecks && !paging.isDefault) {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val results = _queryBy(criteria.and(countCriteria), PageSpecification(), Sort(emptyList()), contractStateType, true) // only skip pagination checks for total results count query
totalStates = results.otherResults.last() as Long
}
val session = getSession() val states: List<StateAndRef<T>> = servicesForResolution.loadStates(
statesMetadata.mapTo(LinkedHashSet()) { it.ref },
ArrayList()
)
val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java) return Vault.Page(states, statesMetadata, totalStatesAvailable, stateTypes, otherResults)
val queryRootVaultStates = criteriaQuery.from(VaultSchemaV1.VaultStates::class.java) }
// TODO: revisit (use single instance of parser for all queries)
val criteriaParser = HibernateQueryCriteriaParser(contractStateType, contractStateTypeMappings, criteriaBuilder, criteriaQuery, queryRootVaultStates)
// parse criteria and build where predicates
criteriaParser.parse(criteria, sorting)
// prepare query for execution
val query = session.createQuery(criteriaQuery)
// pagination checks
if (!skipPagingChecks && !paging.isDefault) {
// pagination
if (paging.pageNumber < DEFAULT_PAGE_NUM) throw VaultQueryException("Page specification: invalid page number ${paging.pageNumber} [page numbers start from $DEFAULT_PAGE_NUM]")
if (paging.pageSize < 1) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [minimum is 1]")
if (paging.pageSize > MAX_PAGE_SIZE) throw VaultQueryException("Page specification: invalid page size ${paging.pageSize} [maximum is $MAX_PAGE_SIZE]")
}
// For both SQLServer and PostgresSQL, firstResult must be >= 0. So we set a floor at 0.
// TODO: This is a catch-all solution. But why is the default pageNumber set to be -1 in the first place?
// Even if we set the default pageNumber to be 1 instead, that may not cover the non-default cases.
// So the floor may be necessary anyway.
query.firstResult = maxOf(0, (paging.pageNumber - 1) * paging.pageSize)
val pageSize = paging.pageSize + 1
query.maxResults = if (pageSize > 0) pageSize else Integer.MAX_VALUE // detection too many results, protected against overflow
// execution
val results = query.resultList
private fun <R> Query<R>.resultStream(paging: PageSpecification): Stream<R> {
return if (paging.isDefault) {
val allResults = resultList
// final pagination check (fail-fast on too many results when no pagination specified) // final pagination check (fail-fast on too many results when no pagination specified)
if (!skipPagingChecks && paging.isDefault && results.size > DEFAULT_PAGE_SIZE) { checkVaultQuery(allResults.size != paging.pageSize + 1) {
throw VaultQueryException("There are ${results.size} results, which exceeds the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. In order to retrieve these results, provide a `PageSpecification(pageNumber, pageSize)` to the method invoked.") "There are more results than the limit of $DEFAULT_PAGE_SIZE for queries that do not specify paging. " +
"In order to retrieve these results, provide a PageSpecification to the method invoked."
} }
val statesAndRefs: MutableList<StateAndRef<T>> = mutableListOf() allResults.stream()
val statesMeta: MutableList<Vault.StateMetadata> = mutableListOf() } else {
val otherResults: MutableList<Any> = mutableListOf() stream()
val stateRefs = mutableSetOf<StateRef>()
results.asSequence()
.forEachIndexed { index, result ->
if (result[0] is VaultSchemaV1.VaultStates) {
if (!paging.isDefault && index == paging.pageSize) // skip last result if paged
return@forEachIndexed
val vaultState = result[0] as VaultSchemaV1.VaultStates
val stateRef = StateRef(SecureHash.create(vaultState.stateRef!!.txId), vaultState.stateRef!!.index)
stateRefs.add(stateRef)
statesMeta.add(Vault.StateMetadata(stateRef,
vaultState.contractStateClassName,
vaultState.recordedTime,
vaultState.consumedTime,
vaultState.stateStatus,
vaultState.notary,
vaultState.lockId,
vaultState.lockUpdateTime,
vaultState.relevancyStatus,
constraintInfo(vaultState.constraintType, vaultState.constraintData)
))
} else {
// TODO: improve typing of returned other results
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }
otherResults.addAll(result.toArray().asList())
}
}
if (stateRefs.isNotEmpty())
statesAndRefs.addAll(uncheckedCast(servicesForResolution.loadStates(stateRefs)))
Vault.Page(states = statesAndRefs, statesMetadata = statesMeta, stateTypes = criteriaParser.stateTypes, totalStatesAvailable = totalStates, otherResults = otherResults)
} }
} }
private fun Query<*>.setResultWindow(paging: PageSpecification) {
if (paging.isDefault) {
// For both SQLServer and PostgresSQL, firstResult must be >= 0.
firstResult = 0
// Peek ahead and see if there are more results in case pagination should be done
maxResults = paging.pageSize + 1
} else {
firstResult = (paging.pageNumber - 1) * paging.pageSize
maxResults = paging.pageSize
}
}
private fun <T : ContractState> queryTotalStateCount(baseCriteria: QueryCriteria, contractStateType: Class<out T>): Long {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val criteria = baseCriteria.and(countCriteria)
val (query) = createQuery(criteria, contractStateType, null)
val results = query.resultList
return results.last().toArray().last() as Long
}
private fun <T : ContractState> createQuery(criteria: QueryCriteria,
contractStateType: Class<out T>,
sorting: Sort?): Pair<Query<Tuple>, Vault.StateStatus> {
val criteriaQuery = criteriaBuilder.createQuery(Tuple::class.java)
val criteriaParser = HibernateQueryCriteriaParser(
contractStateType,
contractStateTypeMappings,
criteriaBuilder,
criteriaQuery,
criteriaQuery.from(VaultSchemaV1.VaultStates::class.java)
)
criteriaParser.parse(criteria, sorting)
val query = getSession().createQuery(criteriaQuery)
return Pair(query, criteriaParser.stateTypes)
}
/** /**
* Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates. * Returns a [DataFeed] containing the results of the provided query, along with the associated observable, containing any subsequent updates.
* *
@ -775,6 +820,12 @@ class NodeVaultService(
} }
} }
private inline fun checkVaultQuery(value: Boolean, lazyMessage: () -> Any) {
if (!value) {
throw VaultQueryException(lazyMessage().toString())
}
}
private fun <T : ContractState> filterContractStates(update: Vault.Update<T>, contractStateType: Class<out T>) = private fun <T : ContractState> filterContractStates(update: Vault.Update<T>, contractStateType: Class<out T>) =
update.copy(consumed = filterByContractState(contractStateType, update.consumed), update.copy(consumed = filterByContractState(contractStateType, update.consumed),
produced = filterByContractState(contractStateType, update.produced)) produced = filterByContractState(contractStateType, update.produced))
@ -802,6 +853,7 @@ class NodeVaultService(
} }
private fun getSession() = database.currentOrNew().session private fun getSession() = database.currentOrNew().session
/** /**
* Derive list from existing vault states and then incrementally update using vault observables * Derive list from existing vault states and then incrementally update using vault observables
*/ */

View File

@ -2,7 +2,9 @@ package net.corda.node.services.vault
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.MAX_ISSUER_REF_SIZE import net.corda.core.contracts.MAX_ISSUER_REF_SIZE
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.UniqueIdentifier import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.toStringShort import net.corda.core.crypto.toStringShort
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
@ -192,3 +194,19 @@ object VaultSchemaV1 : MappedSchema(
) : IndirectStatePersistable<PersistentStateRefAndKey> ) : IndirectStatePersistable<PersistentStateRefAndKey>
} }
fun PersistentStateRef.toStateRef(): StateRef = StateRef(SecureHash.create(txId), index)
fun VaultSchemaV1.VaultStates.toStateMetadata(): Vault.StateMetadata {
return Vault.StateMetadata(
stateRef!!.toStateRef(),
contractStateClassName,
recordedTime,
consumedTime,
stateStatus,
notary,
lockId,
lockUpdateTime,
relevancyStatus,
Vault.ConstraintInfo.constraintInfo(constraintType, constraintData)
)
}

View File

@ -21,6 +21,7 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.toStateRef
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import java.security.PublicKey import java.security.PublicKey
@ -41,6 +42,8 @@ class BFTSmartNotaryService(
) : NotaryService() { ) : NotaryService() {
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
@Suppress("unused") // Used by NotaryLoader via reflection
@JvmStatic @JvmStatic
val serializationFilter val serializationFilter
get() = { clazz: Class<*> -> get() = { clazz: Class<*> ->
@ -147,12 +150,7 @@ class BFTSmartNotaryService(
toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) }, toPersistentEntityKey = { PersistentStateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = { fromPersistentEntity = {
//TODO null check will become obsolete after making DB/JPA columns not nullable //TODO null check will become obsolete after making DB/JPA columns not nullable
val txId = it.id.txId Pair(it.id.toStateRef(), SecureHash.create(it.consumingTxHash))
val index = it.id.index
Pair(
StateRef(txhash = SecureHash.create(txId), index = index),
SecureHash.create(it.consumingTxHash)
)
}, },
toPersistentEntity = { (txHash, index): StateRef, id: SecureHash -> toPersistentEntity = { (txHash, index): StateRef, id: SecureHash ->
CommittedState( CommittedState(

View File

@ -24,6 +24,7 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.node.services.vault.toStateRef
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.notary.common.InternalResult import net.corda.notary.common.InternalResult
@ -142,10 +143,6 @@ class JPAUniquenessProvider(
fun encodeStateRef(s: StateRef): PersistentStateRef { fun encodeStateRef(s: StateRef): PersistentStateRef {
return PersistentStateRef(s.txhash.toString(), s.index) return PersistentStateRef(s.txhash.toString(), s.index)
} }
fun decodeStateRef(s: PersistentStateRef): StateRef {
return StateRef(txhash = SecureHash.create(s.txId), index = s.index)
}
} }
/** /**
@ -215,15 +212,15 @@ class JPAUniquenessProvider(
committedStates.addAll(existing) committedStates.addAll(existing)
} }
return committedStates.map { return committedStates.associate {
val stateRef = StateRef(txhash = SecureHash.create(it.id.txId), index = it.id.index) val stateRef = it.id.toStateRef()
val consumingTxId = SecureHash.create(it.consumingTxHash) val consumingTxId = SecureHash.create(it.consumingTxHash)
if (stateRef in references) { if (stateRef in references) {
stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE) stateRef to StateConsumptionDetails(consumingTxId.reHash(), type = StateConsumptionDetails.ConsumedStateType.REFERENCE_INPUT_STATE)
} else { } else {
stateRef to StateConsumptionDetails(consumingTxId.reHash()) stateRef to StateConsumptionDetails(consumingTxId.reHash())
} }
}.toMap() }
} }
private fun<T> withRetry(block: () -> T): T { private fun<T> withRetry(block: () -> T): T {

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeOpenSSLContextFactory

View File

@ -0,0 +1 @@
net.corda.node.services.messaging.NodeSSLContextFactory

View File

@ -124,7 +124,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
bobNode.internals.disableDBCloseOnStop() bobNode.internals.disableDBCloseOnStop()
bobNode.database.transaction { bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, cashIssuer) VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, cashIssuer, atMostThisManyStates = 10)
} }
val alicesFakePaper = aliceNode.database.transaction { val alicesFakePaper = aliceNode.database.transaction {
@ -233,7 +233,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) {
val issuer = bank.ref(1, 2, 3) val issuer = bank.ref(1, 2, 3)
bobNode.database.transaction { bobNode.database.transaction {
VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, 10, issuer) VaultFiller(bobNode.services, dummyNotary, notary, ::Random).fillWithSomeTestCash(2000.DOLLARS, bankNode.services, 3, issuer, atMostThisManyStates = 10)
} }
val alicesFakePaper = aliceNode.database.transaction { val alicesFakePaper = aliceNode.database.transaction {
fillUpForSeller(false, issuer, alice, fillUpForSeller(false, issuer, alice,

View File

@ -28,12 +28,14 @@ import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.test.SampleCashSchemaV1 import net.corda.finance.test.SampleCashSchemaV1
import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.test.SampleCashSchemaV3
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.services.schema.ContractStateAndRef import net.corda.node.services.schema.ContractStateAndRef
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.schema.PersistentStateService import net.corda.node.services.schema.PersistentStateService
import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.services.vault.toStateRef
import net.corda.node.testing.DummyFungibleContract import net.corda.node.testing.DummyFungibleContract
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
@ -48,7 +50,6 @@ import net.corda.testing.internal.vault.VaultFiller
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.`in`
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.hibernate.SessionFactory import org.hibernate.SessionFactory
@ -122,7 +123,14 @@ class HibernateConfigurationTest {
services = object : MockServices(cordappPackages, BOB_NAME, mock<IdentityService>().also { services = object : MockServices(cordappPackages, BOB_NAME, mock<IdentityService>().also {
doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME }) doReturn(null).whenever(it).verifyAndRegisterIdentity(argThat { name == BOB_NAME })
}, generateKeyPair(), dummyNotary.keyPair) { }, generateKeyPair(), dummyNotary.keyPair) {
override val vaultService = NodeVaultService(Clock.systemUTC(), keyManagementService, servicesForResolution, database, schemaService, cordappClassloader).apply { start() } override val vaultService = NodeVaultService(
Clock.systemUTC(),
keyManagementService,
servicesForResolution as NodeServicesForResolution,
database,
schemaService,
cordappClassloader
).apply { start() }
override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable<SignedTransaction>) { override fun recordTransactions(statesToRecord: StatesToRecord, txs: Iterable<SignedTransaction>) {
for (stx in txs) { for (stx in txs) {
(validatedTransactions as WritableTransactionStorage).addTransaction(stx) (validatedTransactions as WritableTransactionStorage).addTransaction(stx)
@ -183,7 +191,7 @@ class HibernateConfigurationTest {
// execute query // execute query
val queryResults = entityManager.createQuery(criteriaQuery).resultList val queryResults = entityManager.createQuery(criteriaQuery).resultList
val coins = queryResults.map { val coins = queryResults.map {
services.loadState(toStateRef(it.stateRef!!)).data services.loadState(it.stateRef!!.toStateRef()).data
}.sumCash() }.sumCash()
assertThat(coins.toDecimal() >= BigDecimal("50.00")) assertThat(coins.toDecimal() >= BigDecimal("50.00"))
} }
@ -739,7 +747,7 @@ class HibernateConfigurationTest {
val queryResults = entityManager.createQuery(criteriaQuery).resultList val queryResults = entityManager.createQuery(criteriaQuery).resultList
queryResults.forEach { queryResults.forEach {
val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State
println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}") println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}")
} }
@ -823,7 +831,7 @@ class HibernateConfigurationTest {
// execute query // execute query
val queryResults = entityManager.createQuery(criteriaQuery).resultList val queryResults = entityManager.createQuery(criteriaQuery).resultList
queryResults.forEach { queryResults.forEach {
val cashState = services.loadState(toStateRef(it.stateRef!!)).data as Cash.State val cashState = services.loadState(it.stateRef!!.toStateRef()).data as Cash.State
println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}") println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}")
} }
@ -961,10 +969,6 @@ class HibernateConfigurationTest {
} }
} }
private fun toStateRef(pStateRef: PersistentStateRef): StateRef {
return StateRef(SecureHash.create(pStateRef.txId), pStateRef.index)
}
@Test(timeout=300_000) @Test(timeout=300_000)
fun `schema change`() { fun `schema change`() {
fun createNewDB(schemas: Set<MappedSchema>, initialiseSchema: Boolean = true): CordaPersistence { fun createNewDB(schemas: Set<MappedSchema>, initialiseSchema: Boolean = true): CordaPersistence {

View File

@ -244,7 +244,6 @@ class FlowSoftLocksTests {
100.DOLLARS, 100.DOLLARS,
bankNode.services, bankNode.services,
thisManyStates, thisManyStates,
thisManyStates,
cashIssuer cashIssuer
) )
} }

View File

@ -20,14 +20,13 @@ import net.corda.finance.*
import net.corda.finance.contracts.CommercialPaper import net.corda.finance.contracts.CommercialPaper
import net.corda.finance.contracts.Commodity import net.corda.finance.contracts.Commodity
import net.corda.finance.contracts.DealState import net.corda.finance.contracts.DealState
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Cash
import net.corda.finance.schemas.CashSchemaV1 import net.corda.finance.schemas.CashSchemaV1
import net.corda.finance.schemas.CashSchemaV1.PersistentCashState
import net.corda.finance.schemas.CommercialPaperSchemaV1 import net.corda.finance.schemas.CommercialPaperSchemaV1
import net.corda.finance.test.SampleCashSchemaV2 import net.corda.finance.test.SampleCashSchemaV2
import net.corda.finance.test.SampleCashSchemaV3 import net.corda.finance.test.SampleCashSchemaV3
import net.corda.finance.workflows.CommercialPaperUtils import net.corda.finance.workflows.CommercialPaperUtils
import net.corda.finance.workflows.asset.selection.AbstractCashSelection
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
@ -197,8 +196,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE) protected fun consumeCash(amount: Amount<Currency>) = vaultFiller.consumeCash(amount, CHARLIE)
private fun setUpDb(_database: CordaPersistence, delay: Long = 0) {
_database.transaction { private fun setUpDb(database: CordaPersistence, delay: Long = 0) {
database.transaction {
// create new states // create new states
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 10, DUMMY_CASH_ISSUER)
val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ") val linearStatesXYZ = vaultFiller.fillWithSomeTestLinearStates(1, "XYZ")
@ -444,7 +444,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates) Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.state.data.linearNumber }.sortedBy { it.ref.txhash }.sortedBy { it.ref.index }).isEqualTo(allStates)
} }
(1..3).forEach { repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -485,7 +485,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates) Sort.Direction.DESC -> assertThat(allStates.sortedByDescending { it.ref.txhash }.sortedByDescending { it.ref.index }).isEqualTo(allStates)
} }
(1..3).forEach { repeat(3) {
val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states val newAllStates = vaultService.queryBy<DummyLinearContract.State>(sorting = sorting, criteria = criteria).states
assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates) assertThat(newAllStates.groupBy(StateAndRef<*>::ref)).hasSameSizeAs(allStates)
assertThat(newAllStates).containsExactlyElementsOf(allStates) assertThat(newAllStates).containsExactlyElementsOf(allStates)
@ -638,7 +638,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
val sorted = results.states.sortedBy { it.ref.toString() } val sorted = results.states.sortedBy { it.ref.toString() }
assertThat(results.states).isEqualTo(sorted) assertThat(results.states).isEqualTo(sorted)
assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } assertThat(results.states).allSatisfy { assertThat(consumed).doesNotContain(it.ref.txhash) }
} }
} }
@ -1537,7 +1537,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789")) vaultFiller.fillWithSomeTestDeals(listOf("123", "456", "789"))
// count fungible assets // count fungible assets
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) val countCriteria = VaultCustomQueryCriteria(count)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L) assertThat(fungibleStateCount).isEqualTo(10L)
@ -1563,7 +1563,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() }
// count fungible assets // count fungible assets
val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) val countCriteria = VaultCustomQueryCriteria(count, Vault.StateStatus.ALL)
val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long val fungibleStateCount = vaultService.queryBy<FungibleAsset<*>>(countCriteria).otherResults.single() as Long
assertThat(fungibleStateCount).isEqualTo(10L) assertThat(fungibleStateCount).isEqualTo(10L)
@ -1583,7 +1583,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// UNCONSUMED states (default) // UNCONSUMED states (default)
// count fungible assets // count fungible assets
val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) val countCriteriaUnconsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED)
val fungibleStateCountUnconsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long val fungibleStateCountUnconsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaUnconsumed).otherResults.single() as Long
assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size)
@ -1598,7 +1598,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// CONSUMED states // CONSUMED states
// count fungible assets // count fungible assets
val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) val countCriteriaConsumed = VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED)
val fungibleStateCountConsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long val fungibleStateCountConsumed = vaultService.queryBy<FungibleAsset<*>>(countCriteriaConsumed).otherResults.single() as Long
assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size)
@ -1622,7 +1622,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val start = TODAY val start = TODAY
val end = TODAY.plus(30, ChronoUnit.DAYS) val end = TODAY.plus(30, ChronoUnit.DAYS)
val recordedBetweenExpression = TimeCondition( val recordedBetweenExpression = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED, TimeInstantType.RECORDED,
ColumnPredicate.Between(start, end)) ColumnPredicate.Between(start, end))
val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression) val criteria = VaultQueryCriteria(timeCondition = recordedBetweenExpression)
val results = vaultService.queryBy<ContractState>(criteria) val results = vaultService.queryBy<ContractState>(criteria)
@ -1632,7 +1632,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Future // Future
val startFuture = TODAY.plus(1, ChronoUnit.DAYS) val startFuture = TODAY.plus(1, ChronoUnit.DAYS)
val recordedBetweenExpressionFuture = TimeCondition( val recordedBetweenExpressionFuture = TimeCondition(
QueryCriteria.TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end)) TimeInstantType.RECORDED, ColumnPredicate.Between(startFuture, end))
val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture) val criteriaFuture = VaultQueryCriteria(timeCondition = recordedBetweenExpressionFuture)
assertThat(vaultService.queryBy<ContractState>(criteriaFuture).states).isEmpty() assertThat(vaultService.queryBy<ContractState>(criteriaFuture).states).isEmpty()
} }
@ -1648,7 +1648,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
consumeCash(100.DOLLARS) consumeCash(100.DOLLARS)
val asOfDateTime = TODAY val asOfDateTime = TODAY
val consumedAfterExpression = TimeCondition( val consumedAfterExpression = TimeCondition(
QueryCriteria.TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime)) TimeInstantType.CONSUMED, ColumnPredicate.BinaryComparison(BinaryComparisonOperator.GREATER_THAN_OR_EQUAL, asOfDateTime))
val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED, val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED,
timeCondition = consumedAfterExpression) timeCondition = consumedAfterExpression)
val results = vaultService.queryBy<ContractState>(criteria) val results = vaultService.queryBy<ContractState>(criteria)
@ -1674,7 +1674,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// pagination: last page // pagination: last page
@Test(timeout=300_000) @Test(timeout=300_000)
fun `all states with paging specification - last`() { fun `all states with paging specification - last`() {
database.transaction { database.transaction {
vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(95.DOLLARS, notaryServices, 95, DUMMY_CASH_ISSUER)
// Last page implies we need to perform a row count for the Query first, // Last page implies we need to perform a row count for the Query first,
@ -1705,6 +1705,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
// pagination: invalid page size // pagination: invalid page size
@Suppress("INTEGER_OVERFLOW")
@Test(timeout=300_000) @Test(timeout=300_000)
fun `invalid page size`() { fun `invalid page size`() {
expectedEx.expect(VaultQueryException::class.java) expectedEx.expect(VaultQueryException::class.java)
@ -1712,8 +1713,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
database.transaction { database.transaction {
vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(100.DOLLARS, notaryServices, 100, DUMMY_CASH_ISSUER)
@Suppress("EXPECTED_CONDITION") val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, Integer.MAX_VALUE + 1) // overflow = -2147483648
val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, @Suppress("INTEGER_OVERFLOW") Integer.MAX_VALUE + 1) // overflow = -2147483648
val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL)
vaultService.queryBy<ContractState>(criteria, paging = pagingSpec) vaultService.queryBy<ContractState>(criteria, paging = pagingSpec)
} }
@ -1723,7 +1723,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
@Test(timeout=300_000) @Test(timeout=300_000)
fun `pagination not specified but more than default results available`() { fun `pagination not specified but more than default results available`() {
expectedEx.expect(VaultQueryException::class.java) expectedEx.expect(VaultQueryException::class.java)
expectedEx.expectMessage("provide a `PageSpecification(pageNumber, pageSize)`") expectedEx.expectMessage("provide a PageSpecification")
database.transaction { database.transaction {
vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER) vaultFiller.fillWithSomeTestCash(201.DOLLARS, notaryServices, 201, DUMMY_CASH_ISSUER)
@ -1781,9 +1781,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
println("$index : $any") println("$index : $any")
} }
assertThat(results.otherResults.size).isEqualTo(402) assertThat(results.otherResults.size).isEqualTo(402)
val instants = results.otherResults.filter { it is Instant }.map { it as Instant } val instants = results.otherResults.filterIsInstance<Instant>()
assertThat(instants).isSorted assertThat(instants).isSorted
val longs = results.otherResults.filter { it is Long }.map { it as Long } val longs = results.otherResults.filterIsInstance<Long>()
assertThat(longs.size).isEqualTo(201) assertThat(longs.size).isEqualTo(201)
assertThat(instants.size).isEqualTo(201) assertThat(instants.size).isEqualTo(201)
assertThat(longs.sum()).isEqualTo(20100L) assertThat(longs.sum()).isEqualTo(20100L)
@ -1911,8 +1911,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() { fun `LinearStateQueryCriteria returns empty resultset without errors if there is an empty list after the 'in' clause`() {
database.transaction { database.transaction {
val uid = UniqueIdentifier("999") val uid = UniqueIdentifier("999")
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, uniqueIdentifier = uid) vaultFiller.fillWithSomeTestLinearStates(txCount = 1, uniqueIdentifier = uid)
vaultFiller.fillWithSomeTestLinearStates(numberToCreate = 1, externalId = "1234") vaultFiller.fillWithSomeTestLinearStates(txCount = 1, externalId = "1234")
val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id)) val uuidCriteria = LinearStateQueryCriteria(uuid = listOf(uid.id))
val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234")) val externalIdCriteria = LinearStateQueryCriteria(externalId = listOf("1234"))
@ -2061,6 +2061,26 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
} }
} }
@Test(timeout = 300_000)
fun `unconsumed states which are globally unordered across multiple transactions sorted by custom attribute`() {
val linearNumbers = Array(2) { LongArray(2) }
// Make sure states from the same transaction are not given consecutive linear numbers.
linearNumbers[0][0] = 1L
linearNumbers[0][1] = 3L
linearNumbers[1][0] = 2L
linearNumbers[1][1] = 4L
val results = database.transaction {
vaultFiller.fillWithTestStates(txCount = 2, statesPerTx = 2) { participantsToUse, txIndex, stateIndex ->
DummyLinearContract.State(participants = participantsToUse, linearNumber = linearNumbers[txIndex][stateIndex])
}
val sortColumn = Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearNumber"))
vaultService.queryBy<DummyLinearContract.State>(VaultQueryCriteria(), sorting = Sort(setOf(sortColumn)))
}
assertThat(results.states.map { it.state.data.linearNumber }).isEqualTo(listOf(1L, 2L, 3L, 4L))
}
@Test(timeout=300_000) @Test(timeout=300_000)
fun `return consumed linear states for a given linear id`() { fun `return consumed linear states for a given linear id`() {
database.transaction { database.transaction {
@ -2390,7 +2410,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
services.recordTransactions(commercialPaper2) services.recordTransactions(commercialPaper2)
val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) }
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val result = vaultService.queryBy<CommercialPaper.State>(criteria1) val result = vaultService.queryBy<CommercialPaper.State>(criteria1)
@ -2433,9 +2453,9 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days) val maturityIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::maturity.greaterThanOrEqual(TEST_TX_TIME + 30.days)
val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L) val faceValueIndex = CommercialPaperSchemaV1.PersistentCommercialPaperState::faceValue.greaterThanOrEqual(10000L)
val criteria1 = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) val criteria1 = VaultCustomQueryCriteria(ccyIndex)
val criteria2 = QueryCriteria.VaultCustomQueryCriteria(maturityIndex) val criteria2 = VaultCustomQueryCriteria(maturityIndex)
val criteria3 = QueryCriteria.VaultCustomQueryCriteria(faceValueIndex) val criteria3 = VaultCustomQueryCriteria(faceValueIndex)
vaultService.queryBy<CommercialPaper.State>(criteria1.and(criteria3).and(criteria2)) vaultService.queryBy<CommercialPaper.State>(criteria1.and(criteria3).and(criteria2))
} }
@ -2458,8 +2478,8 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL)
val results = builder { val results = builder {
val currencyIndex = PersistentCashState::currency.equal(USD.currencyCode) val currencyIndex = CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode)
val quantityIndex = PersistentCashState::pennies.greaterThanOrEqual(10L) val quantityIndex = CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(10L)
val customCriteria1 = VaultCustomQueryCriteria(currencyIndex) val customCriteria1 = VaultCustomQueryCriteria(currencyIndex)
val customCriteria2 = VaultCustomQueryCriteria(quantityIndex) val customCriteria2 = VaultCustomQueryCriteria(quantityIndex)
@ -2710,7 +2730,7 @@ abstract class VaultQueryTestsBase : VaultQueryParties {
// Enrich and override QueryCriteria with additional default attributes (such as soft locks) // Enrich and override QueryCriteria with additional default attributes (such as soft locks)
val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich val enrichedCriteria = VaultQueryCriteria(contractStateTypes = setOf(DealState::class.java), // enrich
softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())), softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(UUID.randomUUID())),
status = Vault.StateStatus.UNCONSUMED) // override status = Vault.StateStatus.UNCONSUMED) // override
// Sorting // Sorting
val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF)
@ -3056,7 +3076,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0) assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush() this.session.flush()
vaultFiller.consumeLinearStates(states.toList()) vaultFiller.consumeStates(states)
updates updates
} }
@ -3079,7 +3099,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0) assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush() this.session.flush()
vaultFiller.consumeLinearStates(states.toList()) vaultFiller.consumeStates(states)
updates updates
} }
@ -3102,7 +3122,7 @@ class VaultQueryTests : VaultQueryTestsBase(), VaultQueryParties by delegate {
assertThat(snapshot.states).hasSize(0) assertThat(snapshot.states).hasSize(0)
val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states val states = vaultFiller.fillWithSomeTestLinearAndDealStates(10).states
this.session.flush() this.session.flush()
vaultFiller.consumeLinearStates(states.toList()) vaultFiller.consumeStates(states)
updates updates
} }

View File

@ -10,7 +10,6 @@ import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.node.ServicesForResolution
import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.KeyManagementService
import net.corda.core.node.services.queryBy import net.corda.core.node.services.queryBy
import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition import net.corda.core.node.services.vault.QueryCriteria.SoftLockingCondition
@ -29,6 +28,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
import net.corda.testing.flows.registerCoreFlowFactory import net.corda.testing.flows.registerCoreFlowFactory
import net.corda.coretesting.internal.rigorousMock import net.corda.coretesting.internal.rigorousMock
import net.corda.node.internal.NodeServicesForResolution
import net.corda.testing.node.internal.InternalMockNetwork import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.enclosedCordapp import net.corda.testing.node.internal.enclosedCordapp
import net.corda.testing.node.internal.startFlow import net.corda.testing.node.internal.startFlow
@ -86,7 +86,7 @@ class VaultSoftLockManagerTest {
private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args -> private val mockNet = InternalMockNetwork(cordappsForAllNodes = listOf(enclosedCordapp()), defaultFactory = { args ->
object : InternalMockNetwork.MockNode(args) { object : InternalMockNetwork.MockNode(args) {
override fun makeVaultService(keyManagementService: KeyManagementService, override fun makeVaultService(keyManagementService: KeyManagementService,
services: ServicesForResolution, services: NodeServicesForResolution,
database: CordaPersistence, database: CordaPersistence,
cordappLoader: CordappLoader): VaultServiceInternal { cordappLoader: CordappLoader): VaultServiceInternal {
val node = this val node = this

View File

@ -1,30 +1,50 @@
@file:Suppress("UNUSED_PARAMETER")
@file:JvmName("TestUtils") @file:JvmName("TestUtils")
@file:Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod", "LongParameterList")
package net.corda.testing.core package net.corda.testing.core
import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.* import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureScheme
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.toX500Name
import net.corda.core.internal.unspecifiedCountry import net.corda.core.internal.unspecifiedCountry
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.millis import net.corda.core.utilities.millis
import net.corda.core.utilities.minutes
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.createDevNodeCa import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.CertificateType import net.corda.nodeapi.internal.crypto.CertificateType
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import net.corda.coretesting.internal.DEV_ROOT_CA import org.bouncycastle.asn1.x509.CRLReason
import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.ExtensionsGenerator
import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import java.math.BigInteger import java.math.BigInteger
import java.net.URI
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration import java.time.Duration
import java.time.Instant import java.time.Instant
import java.util.*
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.fail import kotlin.test.fail
@ -109,6 +129,44 @@ fun getTestPartyAndCertificate(name: CordaX500Name, publicKey: PublicKey): Party
return getTestPartyAndCertificate(Party(name, publicKey)) return getTestPartyAndCertificate(Party(name, publicKey))
} }
fun createCRL(issuer: CertificateAndKeyPair,
revokedCerts: List<X509Certificate>,
issuingDistPoint: URI? = null,
thisUpdate: Instant = Instant.now(),
nextUpdate: Instant = thisUpdate + 5.minutes,
indirect: Boolean = false,
revocationDate: Instant = thisUpdate,
crlReason: Int = CRLReason.keyCompromise,
signatureAlgorithm: String = "SHA256withECDSA"): X509CRL {
val builder = JcaX509v2CRLBuilder(issuer.certificate.subjectX500Principal, Date.from(thisUpdate))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(issuer.certificate))
// This is required and needs to match the certificate settings with respect to being indirect
builder.addExtension(
Extension.issuingDistributionPoint,
true,
IssuingDistributionPoint(
issuingDistPoint?.let { DistributionPointName(toGeneralNames(it.toString(), GeneralName.uniformResourceIdentifier)) },
indirect,
false
)
)
builder.setNextUpdate(Date.from(nextUpdate))
for (revokedCert in revokedCerts) {
val extensionsGenerator = ExtensionsGenerator()
extensionsGenerator.addExtension(Extension.reasonCode, false, CRLReason.lookup(crlReason))
// Certificate issuer is required for indirect CRL
extensionsGenerator.addExtension(
Extension.certificateIssuer,
true,
GeneralNames(GeneralName(revokedCert.issuerX500Principal.toX500Name()))
)
builder.addCRLEntry(revokedCert.serialNumber, Date.from(revocationDate), extensionsGenerator.generate())
}
val bcProvider = Crypto.findProvider("BC")
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(bcProvider).build(issuer.keyPair.private)
return JcaX509CRLConverter().setProvider(bcProvider).getCRL(builder.build(signer))
}
private val count = AtomicInteger(0) private val count = AtomicInteger(0)
/** /**
@ -188,7 +246,6 @@ fun NodeInfo.singleIdentity(): Party = singleIdentityAndCert().party
* The above will test our expectation that the getWaitingFlows action was executed successfully considering * The above will test our expectation that the getWaitingFlows action was executed successfully considering
* that it may take a few hundreds of milliseconds for the flow state machine states to settle. * that it may take a few hundreds of milliseconds for the flow state machine states to settle.
*/ */
@Suppress("TooGenericExceptionCaught", "MagicNumber", "ComplexMethod")
fun <T> executeTest( fun <T> executeTest(
timeout: Duration, timeout: Duration,
cleanup: (() -> Unit)? = null, cleanup: (() -> Unit)? = null,

View File

@ -28,6 +28,7 @@ import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.ServicesForResolutionImpl import net.corda.node.internal.ServicesForResolutionImpl
import net.corda.node.internal.NodeServicesForResolution
import net.corda.node.internal.cordapp.JarScanningCordappLoader import net.corda.node.internal.cordapp.JarScanningCordappLoader
import net.corda.node.services.api.* import net.corda.node.services.api.*
import net.corda.node.services.diagnostics.NodeDiagnosticsService import net.corda.node.services.diagnostics.NodeDiagnosticsService
@ -463,7 +464,14 @@ open class MockServices private constructor(
get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions) get() = ServicesForResolutionImpl(identityService, attachments, cordappProvider, networkParametersService, validatedTransactions)
internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal { internal fun makeVaultService(schemaService: SchemaService, database: CordaPersistence, cordappLoader: CordappLoader): VaultServiceInternal {
return NodeVaultService(clock, keyManagementService, servicesForResolution, database, schemaService, cordappLoader.appClassLoader).apply { start() } return NodeVaultService(
clock,
keyManagementService,
servicesForResolution as NodeServicesForResolution,
database,
schemaService,
cordappLoader.appClassLoader
).apply { start() }
} }
// This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API // This needs to be internal as MutableClassToInstanceMap is a guava type and shouldn't be part of our public API

View File

@ -4,30 +4,26 @@ package net.corda.testing.node.internal.network
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.internal.CertRole import net.corda.core.internal.CertRole
import net.corda.core.internal.toX500Name
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.days import net.corda.core.utilities.days
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA import net.corda.coretesting.internal.DEV_INTERMEDIATE_CA
import net.corda.coretesting.internal.DEV_ROOT_CA import net.corda.coretesting.internal.DEV_ROOT_CA
import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair import net.corda.nodeapi.internal.crypto.CertificateAndKeyPair
import net.corda.nodeapi.internal.crypto.ContentSignerBuilder import net.corda.nodeapi.internal.crypto.ContentSignerBuilder
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.toGeneralNames
import net.corda.nodeapi.internal.crypto.certificateType import net.corda.nodeapi.internal.crypto.certificateType
import net.corda.nodeapi.internal.crypto.toJca import net.corda.nodeapi.internal.crypto.toJca
import org.bouncycastle.asn1.x500.X500Name import net.corda.testing.core.createCRL
import org.bouncycastle.asn1.x509.CRLDistPoint import org.bouncycastle.asn1.x509.CRLDistPoint
import org.bouncycastle.asn1.x509.DistributionPoint import org.bouncycastle.asn1.x509.DistributionPoint
import org.bouncycastle.asn1.x509.DistributionPointName import org.bouncycastle.asn1.x509.DistributionPointName
import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.Extension
import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralName
import org.bouncycastle.asn1.x509.GeneralNames import org.bouncycastle.asn1.x509.GeneralNames
import org.bouncycastle.asn1.x509.IssuingDistributionPoint
import org.bouncycastle.asn1.x509.ReasonFlags
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.ServerConnector import org.eclipse.jetty.server.ServerConnector
import org.eclipse.jetty.server.handler.HandlerCollection import org.eclipse.jetty.server.handler.HandlerCollection
@ -36,11 +32,12 @@ import org.eclipse.jetty.servlet.ServletHolder
import org.glassfish.jersey.server.ResourceConfig import org.glassfish.jersey.server.ResourceConfig
import org.glassfish.jersey.servlet.ServletContainer import org.glassfish.jersey.servlet.ServletContainer
import java.io.Closeable import java.io.Closeable
import java.math.BigInteger
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.net.URI
import java.security.KeyPair import java.security.KeyPair
import java.security.cert.X509CRL import java.security.cert.X509CRL
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.time.Duration
import java.util.* import java.util.*
import javax.security.auth.x500.X500Principal import javax.security.auth.x500.X500Principal
import javax.ws.rs.GET import javax.ws.rs.GET
@ -51,7 +48,7 @@ import kotlin.collections.ArrayList
class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
companion object { companion object {
private const val SIGNATURE_ALGORITHM = "SHA256withECDSA" private val logger = contextLogger()
const val NODE_CRL = "node.crl" const val NODE_CRL = "node.crl"
const val FORBIDDEN_CRL = "forbidden.crl" const val FORBIDDEN_CRL = "forbidden.crl"
@ -72,8 +69,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
null null
) )
if (crlDistPoint != null) { if (crlDistPoint != null) {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, crlDistPoint))) val distPointName = DistributionPointName(toGeneralNames(crlDistPoint, GeneralName.uniformResourceIdentifier))
val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(X500Name.getInstance(it.encoded))) } val crlIssuerGeneralNames = crlIssuer?.let { GeneralNames(GeneralName(it.toX500Name())) }
val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
} }
@ -87,14 +84,17 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
} }
} }
val revokedNodeCerts: MutableList<BigInteger> = ArrayList() val revokedNodeCerts: MutableList<X509Certificate> = ArrayList()
val revokedIntermediateCerts: MutableList<BigInteger> = ArrayList() val revokedIntermediateCerts: MutableList<X509Certificate> = ArrayList()
val rootCa: CertificateAndKeyPair = DEV_ROOT_CA val rootCa: CertificateAndKeyPair = DEV_ROOT_CA
private lateinit var _intermediateCa: CertificateAndKeyPair private lateinit var _intermediateCa: CertificateAndKeyPair
val intermediateCa: CertificateAndKeyPair get() = _intermediateCa val intermediateCa: CertificateAndKeyPair get() = _intermediateCa
@Volatile
var delay: Duration? = null
val hostAndPort: NetworkHostAndPort val hostAndPort: NetworkHostAndPort
get() = server.connectors.mapNotNull { it as? ServerConnector } get() = server.connectors.mapNotNull { it as? ServerConnector }
.map { NetworkHostAndPort(it.host, it.localPort) } .map { NetworkHostAndPort(it.host, it.localPort) }
@ -106,7 +106,7 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"), DEV_INTERMEDIATE_CA.certificate.withCrlDistPoint(rootCa.keyPair, "http://$hostAndPort/crl/$INTERMEDIATE_CRL"),
DEV_INTERMEDIATE_CA.keyPair DEV_INTERMEDIATE_CA.keyPair
) )
println("Network management web services started on $hostAndPort") logger.info("Network management web services started on $hostAndPort")
} }
fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate, fun replaceNodeCertDistPoint(nodeCaCert: X509Certificate,
@ -115,29 +115,20 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer) return nodeCaCert.withCrlDistPoint(intermediateCa.keyPair, nodeCaCrlDistPoint, crlIssuer)
} }
fun createRevocationList(signatureAlgorithm: String, private fun createServerCRL(issuer: CertificateAndKeyPair,
ca: CertificateAndKeyPair, endpoint: String,
endpoint: String, indirect: Boolean,
indirect: Boolean, revokedCerts: List<X509Certificate>): X509CRL {
serialNumbers: List<BigInteger>): X509CRL { logger.info("Generating CRL for /$endpoint: ${revokedCerts.map { it.serialNumber }}")
println("Generating CRL for $endpoint") return createCRL(
val builder = JcaX509v2CRLBuilder(ca.certificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis())) issuer,
val extensionUtils = JcaX509ExtensionUtils() revokedCerts,
builder.addExtension(Extension.authorityKeyIdentifier, false, extensionUtils.createAuthorityKeyIdentifier(ca.certificate)) issuingDistPoint = URI("http://$hostAndPort/crl/$endpoint"),
val issuingDistPointName = GeneralName(GeneralName.uniformResourceIdentifier, "http://$hostAndPort/crl/$endpoint") indirect = indirect
// This is required and needs to match the certificate settings with respect to being indirect )
val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
serialNumbers.forEach {
builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
}
val signer = JcaContentSignerBuilder(signatureAlgorithm).setProvider(Crypto.findProvider("BC")).build(ca.keyPair.private)
return JcaX509CRLConverter().setProvider(Crypto.findProvider("BC")).getCRL(builder.build(signer))
} }
override fun close() { override fun close() {
println("Shutting down network management web services...")
server.stop() server.stop()
server.join() server.join()
} }
@ -159,8 +150,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(NODE_CRL) @Path(NODE_CRL)
@Produces("application/pkcs7-crl") @Produces("application/pkcs7-crl")
fun getNodeCRL(): Response { fun getNodeCRL(): Response {
return Response.ok(crlServer.createRevocationList( crlServer.delay?.toMillis()?.let(Thread::sleep)
SIGNATURE_ALGORITHM, return Response.ok(crlServer.createServerCRL(
crlServer.intermediateCa, crlServer.intermediateCa,
NODE_CRL, NODE_CRL,
false, false,
@ -179,8 +170,8 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(INTERMEDIATE_CRL) @Path(INTERMEDIATE_CRL)
@Produces("application/pkcs7-crl") @Produces("application/pkcs7-crl")
fun getIntermediateCRL(): Response { fun getIntermediateCRL(): Response {
return Response.ok(crlServer.createRevocationList( crlServer.delay?.toMillis()?.let(Thread::sleep)
SIGNATURE_ALGORITHM, return Response.ok(crlServer.createServerCRL(
crlServer.rootCa, crlServer.rootCa,
INTERMEDIATE_CRL, INTERMEDIATE_CRL,
false, false,
@ -192,11 +183,11 @@ class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
@Path(EMPTY_CRL) @Path(EMPTY_CRL)
@Produces("application/pkcs7-crl") @Produces("application/pkcs7-crl")
fun getEmptyCRL(): Response { fun getEmptyCRL(): Response {
return Response.ok(crlServer.createRevocationList( return Response.ok(crlServer.createServerCRL(
SIGNATURE_ALGORITHM,
crlServer.rootCa, crlServer.rootCa,
EMPTY_CRL, EMPTY_CRL,
true, emptyList() true,
emptyList()
).encoded).build() ).encoded).build()
} }
} }

View File

@ -42,6 +42,7 @@ import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.SchemaMigration import net.corda.nodeapi.internal.persistence.SchemaMigration
import net.corda.nodeapi.internal.protonwrapper.netty.CrlSource
import net.corda.nodeapi.internal.registerDevP2pCertificates import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.serialization.internal.amqp.AMQP_ENABLED import net.corda.serialization.internal.amqp.AMQP_ENABLED
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
@ -52,6 +53,8 @@ import java.io.IOException
import java.net.ServerSocket import java.net.ServerSocket
import java.nio.file.Path import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.* import java.util.*
import java.util.jar.JarOutputStream import java.util.jar.JarOutputStream
import java.util.jar.Manifest import java.util.jar.Manifest
@ -147,6 +150,12 @@ fun p2pSslOptions(path: Path, name: CordaX500Name = CordaX500Name("MegaCorp", "L
return sslConfig return sslConfig
} }
fun fixedCrlSource(crls: Set<X509CRL>): CrlSource {
return object : CrlSource {
override fun fetch(certificate: X509Certificate): Set<X509CRL> = crls
}
}
/** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */ /** This is the same as the deprecated [WireTransaction] c'tor but avoids the deprecation warning. */
@SuppressWarnings("LongParameterList") @SuppressWarnings("LongParameterList")
fun createWireTransaction(inputs: List<StateRef>, fun createWireTransaction(inputs: List<StateRef>,

View File

@ -1,6 +1,20 @@
@file:Suppress("LongParameterList")
package net.corda.testing.internal.vault package net.corda.testing.internal.vault
import net.corda.core.contracts.* import net.corda.core.contracts.Amount
import net.corda.core.contracts.AttachmentConstraint
import net.corda.core.contracts.AutomaticPlaceholderConstraint
import net.corda.core.contracts.BelongsToContract
import net.corda.core.contracts.CommandAndState
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.Issued
import net.corda.core.contracts.LinearState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.SignatureMetadata
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
@ -19,9 +33,7 @@ import net.corda.finance.contracts.asset.Cash
import net.corda.finance.contracts.asset.Obligation import net.corda.finance.contracts.asset.Obligation
import net.corda.finance.contracts.asset.OnLedgerAsset import net.corda.finance.contracts.asset.OnLedgerAsset
import net.corda.finance.workflows.asset.CashUtils import net.corda.finance.workflows.asset.CashUtils
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState import net.corda.testing.contracts.DummyState
import net.corda.testing.core.DummyCommandData
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.core.dummyCommand import net.corda.testing.core.dummyCommand
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
@ -32,6 +44,7 @@ import java.time.Duration
import java.time.Instant import java.time.Instant
import java.time.Instant.now import java.time.Instant.now
import java.util.* import java.util.*
import kotlin.math.floor
/** /**
* The service hub should provide at least a key management service and a storage service. * The service hub should provide at least a key management service and a storage service.
@ -46,7 +59,7 @@ class VaultFiller @JvmOverloads constructor(
private val rngFactory: () -> Random = { Random(0L) }) { private val rngFactory: () -> Random = { Random(0L) }) {
companion object { companion object {
fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray { fun calculateRandomlySizedAmounts(howMuch: Amount<Currency>, min: Int, max: Int, rng: Random): LongArray {
val numSlots = min + Math.floor(rng.nextDouble() * (max - min)).toInt() val numSlots = min + floor(rng.nextDouble() * (max - min)).toInt()
val baseSize = howMuch.quantity / numSlots val baseSize = howMuch.quantity / numSlots
check(baseSize > 0) { baseSize } check(baseSize > 0) { baseSize }
@ -79,31 +92,18 @@ class VaultFiller @JvmOverloads constructor(
issuerServices: ServiceHub = services, issuerServices: ServiceHub = services,
participants: List<AbstractParty> = emptyList(), participants: List<AbstractParty> = emptyList(),
includeMe: Boolean = true): Vault<DealState> { includeMe: Boolean = true): Vault<DealState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey return fillWithTestStates(
val me = AnonymousParty(myKey) txCount = dealIds.size,
val participantsToUse = if (includeMe) participants.plus(me) else participants participants = participants,
includeMe = includeMe,
val transactions: List<SignedTransaction> = dealIds.map { services = issuerServices
// Issue a deal state ) { participantsToUse, txIndex, _ ->
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { DummyDealContract.State(ref = dealIds[txIndex], participants = participantsToUse)
addOutputState(DummyDealContract.State(ref = it, participants = participantsToUse), DUMMY_DEAL_PROGRAM_ID)
addCommand(dummyCommand())
}
val stx = issuerServices.signInitialTransaction(dummyIssue)
return@map services.addSignature(stx, defaultNotary.publicKey)
} }
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<DealState>(i) }
}
return Vault(states)
} }
@JvmOverloads @JvmOverloads
fun fillWithSomeTestLinearStates(numberToCreate: Int, fun fillWithSomeTestLinearStates(txCount: Int,
externalId: String? = null, externalId: String? = null,
participants: List<AbstractParty> = emptyList(), participants: List<AbstractParty> = emptyList(),
uniqueIdentifier: UniqueIdentifier? = null, uniqueIdentifier: UniqueIdentifier? = null,
@ -113,81 +113,41 @@ class VaultFiller @JvmOverloads constructor(
linearTimestamp: Instant = now(), linearTimestamp: Instant = now(),
constraint: AttachmentConstraint = AutomaticPlaceholderConstraint, constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
includeMe: Boolean = true): Vault<LinearState> { includeMe: Boolean = true): Vault<LinearState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey return fillWithTestStates(txCount, 1, participants, constraint, includeMe) { participantsToUse, _, _ ->
val me = AnonymousParty(myKey) DummyLinearContract.State(
val issuerKey = defaultNotary.keyPair linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) participants = participantsToUse,
val participantsToUse = if (includeMe) participants.plus(me) else participants linearString = linearString,
val transactions: List<SignedTransaction> = (1..numberToCreate).map { linearNumber = linearNumber,
// Issue a Linear state linearBoolean = linearBoolean,
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply { linearTimestamp = linearTimestamp
addOutputState(DummyLinearContract.State( )
linearId = uniqueIdentifier ?: UniqueIdentifier(externalId),
participants = participantsToUse,
linearString = linearString,
linearNumber = linearNumber,
linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID,
constraint = constraint)
addCommand(dummyCommand())
}
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
} }
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
} }
@JvmOverloads @JvmOverloads
fun fillWithSomeTestLinearAndDealStates(numberToCreate: Int, fun fillWithSomeTestLinearAndDealStates(txCount: Int,
externalId: String? = null, externalId: String? = null,
participants: List<AbstractParty> = emptyList(), participants: List<AbstractParty> = emptyList(),
linearString: String = "", linearString: String = "",
linearNumber: Long = 0L, linearNumber: Long = 0L,
linearBoolean: Boolean = false, linearBoolean: Boolean = false,
linearTimestamp: Instant = now()): Vault<LinearState> { linearTimestamp: Instant = now()): Vault<ContractState> {
val myKey: PublicKey = services.myInfo.chooseIdentity().owningKey return fillWithTestStates(txCount, 2, participants) { participantsToUse, _, stateIndex ->
val me = AnonymousParty(myKey) when (stateIndex) {
val issuerKey = defaultNotary.keyPair 0 -> DummyLinearContract.State(
val signatureMetadata = SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID)
val transactions: List<SignedTransaction> = (1..numberToCreate).map {
val dummyIssue = TransactionBuilder(notary = defaultNotary.party).apply {
// Issue a Linear state
addOutputState(DummyLinearContract.State(
linearId = UniqueIdentifier(externalId), linearId = UniqueIdentifier(externalId),
participants = participants.plus(me), participants = participantsToUse,
linearString = linearString, linearString = linearString,
linearNumber = linearNumber, linearNumber = linearNumber,
linearBoolean = linearBoolean, linearBoolean = linearBoolean,
linearTimestamp = linearTimestamp), DUMMY_LINEAR_CONTRACT_PROGRAM_ID) linearTimestamp = linearTimestamp
// Issue a Deal state )
addOutputState(DummyDealContract.State(ref = "test ref", participants = participants.plus(me)), DUMMY_DEAL_PROGRAM_ID) else -> DummyDealContract.State(ref = "test ref", participants = participantsToUse)
addCommand(dummyCommand())
} }
return@map services.signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata)
} }
services.recordTransactions(transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<LinearState>(i) }
}
return Vault(states)
} }
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub,
thisManyStates: Int,
issuedBy: PartyAndReference,
owner: AbstractParty? = null,
rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT) = fillWithSomeTestCash(howMuch, issuerServices, thisManyStates, thisManyStates, issuedBy, owner, rng, statesToRecord)
/** /**
* Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them * Creates a random set of between (by default) 3 and 10 cash states that add up to the given amount and adds them
* to the vault. This is intended for unit tests. By default the cash is owned by the legal * to the vault. This is intended for unit tests. By default the cash is owned by the legal
@ -196,14 +156,15 @@ class VaultFiller @JvmOverloads constructor(
* @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @param issuerServices service hub of the issuer node, which will be used to sign the transaction.
* @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!).
*/ */
@JvmOverloads
fun fillWithSomeTestCash(howMuch: Amount<Currency>, fun fillWithSomeTestCash(howMuch: Amount<Currency>,
issuerServices: ServiceHub, issuerServices: ServiceHub,
atLeastThisManyStates: Int, atLeastThisManyStates: Int,
atMostThisManyStates: Int,
issuedBy: PartyAndReference, issuedBy: PartyAndReference,
owner: AbstractParty? = null, owner: AbstractParty? = null,
rng: Random? = null, rng: Random? = null,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<Cash.State> { statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT,
atMostThisManyStates: Int = atLeastThisManyStates): Vault<Cash.State> {
val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory()) val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng ?: rngFactory())
// We will allocate one state to one transaction, for simplicities sake. // We will allocate one state to one transaction, for simplicities sake.
val cash = Cash() val cash = Cash()
@ -212,39 +173,46 @@ class VaultFiller @JvmOverloads constructor(
cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary) cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy, howMuch.token)), owner ?: services.myInfo.singleIdentity(), altNotary)
return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
} }
services.recordTransactions(statesToRecord, transactions) return recordTransactions(transactions, statesToRecord)
// Get all the StateRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<Cash.State>(i) }
}
return Vault(states)
} }
/** /**
* Records a dummy state in the Vault (useful for creating random states when testing vault queries) * Records a dummy state in the Vault (useful for creating random states when testing vault queries)
*/ */
fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())) : Vault<DummyState> { fun fillWithDummyState(participants: List<AbstractParty> = listOf(services.myInfo.singleIdentity())): Vault<DummyState> {
val outputState = TransactionState( return fillWithTestStates(participants = participants) { participantsToUse, _, _ ->
data = DummyState(Random().nextInt(), participants = participants), DummyState(Random().nextInt(), participants = participantsToUse)
contract = DummyContract.PROGRAM_ID, }
notary = defaultNotary.party
)
val participantKeys : List<PublicKey> = participants.map { it.owningKey }
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, participantKeys)
val stxn = services.signInitialTransaction(builder)
services.recordTransactions(stxn)
return Vault(setOf(stxn.tx.outRef(0)))
} }
/** fun <T : ContractState> fillWithTestStates(txCount: Int = 1,
* Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. statesPerTx: Int = 1,
*/ participants: List<AbstractParty> = emptyList(),
fun generateCommoditiesIssue(tx: TransactionBuilder, amount: Amount<Issued<Commodity>>, owner: AbstractParty, notary: Party) constraint: AttachmentConstraint = AutomaticPlaceholderConstraint,
= OnLedgerAsset.generateIssue(tx, TransactionState(CommodityState(amount, owner), Obligation.PROGRAM_ID, notary), Obligation.Commands.Issue()) includeMe: Boolean = true,
services: ServiceHub = this.services,
genOutputState: (participantsToUse: List<AbstractParty>, txIndex: Int, stateIndex: Int) -> T): Vault<T> {
val issuerKey = defaultNotary.keyPair
val signatureMetadata = SignatureMetadata(
services.myInfo.platformVersion,
Crypto.findSignatureScheme(issuerKey.public).schemeNumberID
)
val participantsToUse = if (includeMe) {
participants + AnonymousParty(this.services.myInfo.chooseIdentity().owningKey)
} else {
participants
}
val transactions = Array(txCount) { txIndex ->
val builder = TransactionBuilder(notary = defaultNotary.party)
repeat(statesPerTx) { stateIndex ->
builder.addOutputState(genOutputState(participantsToUse, txIndex, stateIndex), constraint)
}
builder.addCommand(dummyCommand())
services.signInitialTransaction(builder).withAdditionalSignature(issuerKey, signatureMetadata)
}
val statesToRecord = if (includeMe) StatesToRecord.ONLY_RELEVANT else StatesToRecord.ALL_VISIBLE
return recordTransactions(transactions.asList(), statesToRecord)
}
/** /**
* *
@ -257,13 +225,16 @@ class VaultFiller @JvmOverloads constructor(
val me = AnonymousParty(myKey) val me = AnonymousParty(myKey)
val issuance = TransactionBuilder(null as Party?) val issuance = TransactionBuilder(null as Party?)
generateCommoditiesIssue(issuance, Amount(amount.quantity, Issued(issuedBy, amount.token)), me, altNotary) OnLedgerAsset.generateIssue(
issuance,
TransactionState(CommodityState(Amount(amount.quantity, Issued(issuedBy, amount.token)), me), Obligation.PROGRAM_ID, altNotary),
Obligation.Commands.Issue()
)
val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey)
services.recordTransactions(transaction) return recordTransactions(listOf(transaction))
return Vault(setOf(transaction.tx.outRef(0)))
} }
private fun <T : LinearState> consume(states: List<StateAndRef<T>>) { fun consumeStates(states: Iterable<StateAndRef<*>>) {
// Create a txn consuming different contract types // Create a txn consuming different contract types
states.forEach { states.forEach {
val builder = TransactionBuilder(notary = altNotary).apply { val builder = TransactionBuilder(notary = altNotary).apply {
@ -300,10 +271,11 @@ class VaultFiller @JvmOverloads constructor(
} }
} }
fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consume(dealStates) fun consumeDeals(dealStates: List<StateAndRef<DealState>>) = consumeStates(dealStates)
fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consume(linearStates) fun consumeLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeStates(linearStates)
fun evolveLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeAndProduce(linearStates) fun evolveLinearStates(linearStates: List<StateAndRef<LinearState>>) = consumeAndProduce(linearStates)
fun evolveLinearState(linearState: StateAndRef<LinearState>): StateAndRef<LinearState> = consumeAndProduce(linearState) fun evolveLinearState(linearState: StateAndRef<LinearState>): StateAndRef<LinearState> = consumeAndProduce(linearState)
/** /**
* Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios, * Consume cash, sending any change to the default identity for this node. Only suitable for use in test scenarios,
* where nodes have a default identity. * where nodes have a default identity.
@ -319,6 +291,16 @@ class VaultFiller @JvmOverloads constructor(
services.recordTransactions(spendTx) services.recordTransactions(spendTx)
return update.getOrThrow(Duration.ofSeconds(3)) return update.getOrThrow(Duration.ofSeconds(3))
} }
private fun <T : ContractState> recordTransactions(transactions: Iterable<SignedTransaction>,
statesToRecord: StatesToRecord = StatesToRecord.ONLY_RELEVANT): Vault<T> {
services.recordTransactions(statesToRecord, transactions)
// Get all the StateAndRefs of all the generated transactions.
val states = transactions.flatMap { stx ->
stx.tx.outputs.indices.map { i -> stx.tx.outRef<T>(i) }
}
return Vault(states)
}
} }
@ -344,4 +326,3 @@ data class CommodityState(
override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner)) override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Obligation.Commands.Move(), copy(owner = newOwner))
} }