Mirror of https://github.com/corda/corda.git

Merge branch 'release/os/4.5' into jamesh/error-reporting-sync-29-04-20

# Conflicts:
#	node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
@@ -4,6 +4,9 @@ import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport
import net.corda.nodeapi.internal.ArtemisTcpTransport.Companion.p2pConnectorTcpTransportFromList
import net.corda.nodeapi.internal.config.MessagingServerConnectionConfiguration
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE

@@ -17,28 +20,55 @@ interface ArtemisSessionProvider {
class ArtemisMessagingClient(private val config: MutualSslConfiguration,
private val serverAddress: NetworkHostAndPort,
private val maxMessageSize: Int,
private val failoverCallback: ((FailoverEventType) -> Unit)? = null) : ArtemisSessionProvider {
private val autoCommitSends: Boolean = true,
private val autoCommitAcks: Boolean = true,
private val confirmationWindowSize: Int = -1,
private val messagingServerConnectionConfig: MessagingServerConnectionConfiguration? = null,
private val backupServerAddressPool: List<NetworkHostAndPort> = emptyList(),
private val failoverCallback: ((FailoverEventType) -> Unit)? = null
) : ArtemisSessionProvider {
companion object {
private val log = loggerFor<ArtemisMessagingClient>()
const val CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME = "net.corda.nodeapi.artemismessagingclient.CallTimeout"
const val CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT = 5000L
}

class Started(val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer)
class Started(val serverLocator: ServerLocator, val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer)

override var started: Started? = null
private set

override fun start(): Started = synchronized(this) {
check(started == null) { "start can't be called twice" }
val tcpTransport = p2pConnectorTcpTransport(serverAddress, config)
val backupTransports = p2pConnectorTcpTransportFromList(backupServerAddressPool, config)

log.info("Connecting to message broker: $serverAddress")
// TODO Add broker CN to config for host verification in case the embedded broker isn't used
val tcpTransport = ArtemisTcpTransport.p2pConnectorTcpTransport(serverAddress, config)
val locator = ActiveMQClient.createServerLocatorWithoutHA(tcpTransport).apply {
if (backupTransports.isNotEmpty()) {
log.info("Back-up message broker addresses: $backupServerAddressPool")
}
// If back-up artemis addresses are configured, the locator will be created using HA mode.
@Suppress("SpreadOperator")
val locator = ActiveMQClient.createServerLocator(backupTransports.isNotEmpty(), *(listOf(tcpTransport) + backupTransports).toTypedArray()).apply {
// Never time out on our loopback Artemis connections. If we switch back to using the InVM transport this
// would be the default and the two lines below can be deleted.
connectionTTL = 60000
clientFailureCheckPeriod = 30000
callFailoverTimeout = java.lang.Long.getLong(CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME, CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT)
callTimeout = java.lang.Long.getLong(CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME, CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT)
minLargeMessageSize = maxMessageSize
isUseGlobalPools = nodeSerializationEnv != null
confirmationWindowSize = this@ArtemisMessagingClient.confirmationWindowSize
producerWindowSize = -1
messagingServerConnectionConfig?.let {
connectionLoadBalancingPolicyClassName = RoundRobinConnectionPolicy::class.java.canonicalName
reconnectAttempts = messagingServerConnectionConfig.reconnectAttempts(isHA)
retryInterval = messagingServerConnectionConfig.retryInterval().toMillis()
retryIntervalMultiplier = messagingServerConnectionConfig.retryIntervalMultiplier()
maxRetryInterval = messagingServerConnectionConfig.maxRetryInterval(isHA).toMillis()
isFailoverOnInitialConnection = messagingServerConnectionConfig.failoverOnInitialAttempt(isHA)
initialConnectAttempts = messagingServerConnectionConfig.initialConnectAttempts(isHA)
}
addIncomingInterceptor(ArtemisMessageSizeChecksInterceptor(maxMessageSize))
}
val sessionFactory = locator.createSessionFactory()

@@ -50,23 +80,24 @@ class ArtemisMessagingClient(private val config: MutualSslConfiguration,
// using our TLS certificate.
// Note that the acknowledgement of messages is not flushed to the Artemis journal until the default buffer
// size of 1MB is acknowledged.
val session = sessionFactory!!.createSession(NODE_P2P_USER, NODE_P2P_USER, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
val session = sessionFactory!!.createSession(NODE_P2P_USER, NODE_P2P_USER, false, autoCommitSends, autoCommitAcks, false, DEFAULT_ACK_BATCH_SIZE)
session.start()
// Create a general purpose producer.
val producer = session.createProducer()
return Started(sessionFactory, session, producer).also { started = it }
return Started(locator, sessionFactory, session, producer).also { started = it }
}

override fun stop() = synchronized(this) {
started?.run {
producer.close()
// Since we are leaking the session outside of this class it may well be already closed.
if(!session.isClosed) {
if (session.stillOpen()) {
// Ensure any trailing messages are committed to the journal
session.commit()
}
// Closing the factory closes all the sessions it produced as well.
sessionFactory.close()
serverLocator.close()
}
started = null
}
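// Illustrative sketch (not part of the diff above): the call and failover timeouts set on the
// locator are resolved via java.lang.Long.getLong, so they can be overridden per JVM with a
// system property, e.g. -Dnet.corda.nodeapi.artemismessagingclient.CallTimeout=10000.
// Equivalent lookup in code, using the constants from the companion object:
val callTimeoutMillis: Long = java.lang.Long.getLong(
        ArtemisMessagingClient.CORDA_ARTEMIS_CALL_TIMEOUT_PROP_NAME,
        ArtemisMessagingClient.CORDA_ARTEMIS_CALL_TIMEOUT_DEFAULT)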
@@ -1,6 +1,5 @@
package net.corda.nodeapi.internal

import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.toStringShort
import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipientGroup

@@ -34,6 +33,7 @@ class ArtemisMessagingComponent {
// This is a rough guess on the extra space needed on top of maxMessageSize to store the journal.
// TODO: we might want to make this value configurable.
const val JOURNAL_HEADER_SIZE = 1024

object P2PMessagingHeaders {
// This is a "property" attached to an Artemis MQ message object, which contains our own notion of "topic".
// We should probably try to unify our notion of "topic" (really, just a string that identifies an endpoint

@@ -123,6 +123,11 @@ class ArtemisMessagingComponent {
require(address.startsWith(PEERS_PREFIX)) { "Failed to map address: $address to a remote topic as it is not in the $PEERS_PREFIX namespace" }
return P2P_PREFIX + address.substring(PEERS_PREFIX.length)
}

fun translateInboxAddressToLocalQueue(address: String): String {
require(address.startsWith(P2P_PREFIX)) { "Failed to map topic: $address to a local address as it is not in the $P2P_PREFIX namespace" }
return PEERS_PREFIX + address.substring(P2P_PREFIX.length)
}
}

override val queueName: String = "$P2P_PREFIX${identity.toStringShort()}"

@@ -100,35 +100,43 @@ class ArtemisTcpTransport {

fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration {
return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL)
return p2pAcceptorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false)
}

fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true): TransportConfiguration {
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, config: MutualSslConfiguration?, enableSSL: Boolean = true, keyStoreProvider: String? = null): TransportConfiguration {
return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL)
return p2pConnectorTcpTransport(hostAndPort, config?.keyStore, config?.trustStore, enableSSL = enableSSL, useOpenSsl = config?.useOpenSsl ?: false, keyStoreProvider = keyStoreProvider)
}

fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true): TransportConfiguration {
fun p2pAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false): TransportConfiguration {
val options = defaultArtemisOptions(hostAndPort).toMutableMap()
if (enableSSL) {
options.putAll(defaultSSLOptions)
(keyStore to trustStore).addToTransportOptions(options)
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
}
options[TransportConstants.HANDSHAKE_TIMEOUT] = 0 // Suppress core.server.lambda$channelActive$0 - AMQ224088 error from load balancer type connections
return TransportConfiguration(acceptorFactoryClassName, options)
}

fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true): TransportConfiguration {
@Suppress("LongParameterList")
fun p2pConnectorTcpTransport(hostAndPort: NetworkHostAndPort, keyStore: FileBasedCertificateStoreSupplier?, trustStore: FileBasedCertificateStoreSupplier?, enableSSL: Boolean = true, useOpenSsl: Boolean = false, keyStoreProvider: String? = null): TransportConfiguration {
val options = defaultArtemisOptions(hostAndPort).toMutableMap()
if (enableSSL) {
options.putAll(defaultSSLOptions)
(keyStore to trustStore).addToTransportOptions(options)
options[TransportConstants.SSL_PROVIDER] = if (useOpenSsl) TransportConstants.OPENSSL_PROVIDER else TransportConstants.DEFAULT_SSL_PROVIDER
keyStoreProvider?.let { options.put(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME, keyStoreProvider) }
}
return TransportConfiguration(connectorFactoryClassName, options)
}

fun p2pConnectorTcpTransportFromList(hostAndPortList: List<NetworkHostAndPort>, config: MutualSslConfiguration?, enableSSL: Boolean = true, keyStoreProvider: String? = null): List<TransportConfiguration> = hostAndPortList.map {
p2pConnectorTcpTransport(it, config, enableSSL, keyStoreProvider)
}

fun rpcAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: BrokerRpcSslOptions?, enableSSL: Boolean = true): TransportConfiguration {
val options = defaultArtemisOptions(hostAndPort).toMutableMap()

@@ -156,12 +164,17 @@ class ArtemisTcpTransport {
rpcConnectorTcpTransport(it, config, enableSSL)
}

fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration): TransportConfiguration {
return TransportConfiguration(connectorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + config.toTransportOptions())
fun rpcInternalClientTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, keyStoreProvider: String? = null): TransportConfiguration {
return TransportConfiguration(connectorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + config.toTransportOptions() + asMap(keyStoreProvider))
}

fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration): TransportConfiguration {
return TransportConfiguration(acceptorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions + config.toTransportOptions() + (TransportConstants.HANDSHAKE_TIMEOUT to 0))
fun rpcInternalAcceptorTcpTransport(hostAndPort: NetworkHostAndPort, config: SslConfiguration, keyStoreProvider: String? = null): TransportConfiguration {
return TransportConfiguration(acceptorFactoryClassName, defaultArtemisOptions(hostAndPort) + defaultSSLOptions +
config.toTransportOptions() + (TransportConstants.HANDSHAKE_TIMEOUT to 0) + asMap(keyStoreProvider))
}

private fun asMap(keyStoreProvider: String?): Map<String, String> {
return keyStoreProvider?.let { mutableMapOf(TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to it) } ?: emptyMap()
}
}
}
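// Illustrative sketch (not part of the diff above): the new keyStoreProvider parameter simply adds
// TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to the transport options. The address and provider
// name below are hypothetical; the imports are the same ones used by the file above.
fun internalRpcTransportWithProvider(sslConfig: SslConfiguration): TransportConfiguration =
        ArtemisTcpTransport.rpcInternalClientTcpTransport(
                NetworkHostAndPort("broker.example.com", 10010), // hypothetical broker address
                sslConfig,
                keyStoreProvider = "SunPKCS11")                  // hypothetical keystore provider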
@@ -1,5 +1,4 @@
@file:JvmName("ArtemisUtils")

package net.corda.nodeapi.internal

import java.nio.file.FileSystems

@@ -16,3 +15,4 @@ fun Path.requireOnDefaultFileSystem() {
fun requireMessageSize(messageSize: Int, limit: Int) {
require(messageSize <= limit) { "Message exceeds maxMessageSize network parameter, maxMessageSize: [$limit], message size: [$messageSize]" }
}

@@ -0,0 +1,8 @@
package net.corda.nodeapi.internal

import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.core.client.impl.ClientSessionInternal

fun ClientSession.stillOpen(): Boolean {
return (!isClosed && (this as? ClientSessionInternal)?.isClosing != true)
}
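// Illustrative sketch (not part of the diff above): how stillOpen() is intended to be used when
// tearing a session down, mirroring ArtemisMessagingClient.stop() - commit only while the session
// is neither closed nor in the process of closing. Assumes the imports from the file above.
fun flushAndClose(session: ClientSession) {
    if (session.stillOpen()) {
        session.commit() // flush any trailing acknowledgements to the journal
    }
    session.close()
}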
@@ -0,0 +1,17 @@
package net.corda.nodeapi.internal

import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write

/**
* A [ConcurrentBox] allows the implementation of track() with reduced contention. [concurrent] may be run from several
* threads (which means it MUST be threadsafe!), while [exclusive] stops the world until the tracking has been set up.
* Internally [ConcurrentBox] is implemented simply as a read-write lock.
*/
class ConcurrentBox<out T>(val content: T) {
val lock = ReentrantReadWriteLock()

inline fun <R> concurrent(block: T.() -> R): R = lock.read { block(content) }
inline fun <R> exclusive(block: T.() -> R): R = lock.write { block(content) }
}
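// Illustrative sketch (not part of the diff above): typical ConcurrentBox usage. Reads that may run
// in parallel go through concurrent {}, while set-up that must stop the world goes through
// exclusive {}; both lambdas receive the boxed content as their receiver.
class QueueRegistry {
    private val queues = ConcurrentBox(mutableListOf<String>())

    fun register(name: String) = queues.exclusive { add(name) }   // takes the write lock
    fun snapshot(): List<String> = queues.concurrent { toList() } // takes the read lock
}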
@@ -0,0 +1,18 @@
package net.corda.nodeapi.internal

import org.apache.activemq.artemis.api.core.client.loadbalance.ConnectionLoadBalancingPolicy

/**
* Implementation of an Artemis load balancing policy. It does round-robin always starting from the first position, whereas
* the current [RoundRobinConnectionLoadBalancingPolicy] in Artemis picks the starting position randomly. This can lead to
* attempting to connect to an inactive broker on the first attempt, which can cause start-up delays depending on what connection
* settings are used.
*/
class RoundRobinConnectionPolicy : ConnectionLoadBalancingPolicy {
private var pos = 0

override fun select(max: Int): Int {
pos = if (pos >= max) 0 else pos
return pos++
}
}
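// Illustrative sketch (not part of the diff above): with three candidate brokers the policy always
// starts at index 0 and then cycles deterministically, unlike Artemis' built-in round-robin policy
// which picks a random starting position.
fun demoSelectionOrder(): List<Int> {
    val policy = RoundRobinConnectionPolicy()
    return (1..6).map { policy.select(3) } // yields 0, 1, 2, 0, 1, 2
}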
@@ -1,5 +1,8 @@
@file:Suppress("TooGenericExceptionCaught") // needs to catch and handle/rethrow *all* exceptions in many places
package net.corda.nodeapi.internal.bridging

import com.google.common.util.concurrent.ThreadFactoryBuilder
import io.netty.channel.EventLoop
import io.netty.channel.EventLoopGroup
import io.netty.channel.nio.NioEventLoopGroup
import net.corda.core.identity.CordaX500Name

@@ -11,11 +14,14 @@ import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_U
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateLocalQueueToInboxAddress
import net.corda.nodeapi.internal.ArtemisSessionProvider
import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientConsumer

@@ -23,6 +29,10 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC
import rx.Subscription
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock

@@ -34,33 +44,46 @@ import kotlin.concurrent.withLock
* The Netty thread pool used by the AMQPBridges is also shared and managed by the AMQPBridgeManager.
*/
@VisibleForTesting
class AMQPBridgeManager(config: MutualSslConfiguration,
maxMessageSize: Int,
crlCheckSoftFail: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null) : BridgeManager {
open class AMQPBridgeManager(keyStore: CertificateStore,
trustStore: CertificateStore,
useOpenSSL: Boolean,
proxyConfig: ProxyConfig? = null,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean,
sslHandshakeTimeout: Long?,
private val bridgeConnectionTTLSeconds: Int) : BridgeManager {

private val lock = ReentrantLock()
private val queueNamesToBridgesMap = mutableMapOf<String, MutableList<AMQPBridge>>()

private class AMQPConfigurationImpl private constructor(override val keyStore: CertificateStore,
override val trustStore: CertificateStore,
override val maxMessageSize: Int,
override val crlCheckSoftFail: Boolean) : AMQPConfiguration {
constructor(config: MutualSslConfiguration, maxMessageSize: Int, crlCheckSoftFail: Boolean) : this(config.keyStore.get(), config.trustStore.get(), maxMessageSize, crlCheckSoftFail)
private class AMQPConfigurationImpl(override val keyStore: CertificateStore,
override val trustStore: CertificateStore,
override val proxyConfig: ProxyConfig?,
override val maxMessageSize: Int,
override val revocationConfig: RevocationConfig,
override val useOpenSsl: Boolean,
override val enableSNI: Boolean,
override val sourceX500Name: String? = null,
override val trace: Boolean,
private val _sslHandshakeTimeout: Long?) : AMQPConfiguration {
override val sslHandshakeTimeout: Long
get() = _sslHandshakeTimeout ?: super.sslHandshakeTimeout
}

private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(config, maxMessageSize, crlCheckSoftFail)
private val amqpConfig: AMQPConfiguration = AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize, revocationConfig, useOpenSSL, enableSNI, trace = trace, _sslHandshakeTimeout = sslHandshakeTimeout)
private var sharedEventLoopGroup: EventLoopGroup? = null
private var artemis: ArtemisSessionProvider? = null

constructor(config: MutualSslConfiguration,
p2pAddress: NetworkHostAndPort,
maxMessageSize: Int,
crlCheckSoftFail: Boolean) : this(config, maxMessageSize, crlCheckSoftFail, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) })

companion object {
private const val NUM_BRIDGE_THREADS = 0 // Default sized pool

private const val CORDA_NUM_BRIDGE_THREADS_PROP_NAME = "net.corda.nodeapi.amqpbridgemanager.NumBridgeThreads"

private val NUM_BRIDGE_THREADS = Integer.getInteger(CORDA_NUM_BRIDGE_THREADS_PROP_NAME, 0) // Default 0 means Netty default sized pool
private const val ARTEMIS_RETRY_BACKOFF = 5000L
}

/**
@@ -71,13 +94,16 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
* If the delivery fails the session is rolled back to prevent loss of the message. This may cause duplicate delivery,
* however Artemis and the remote Corda instances will deduplicate these messages.
*/
private class AMQPBridge(val queueName: String,
@Suppress("TooManyFunctions")
private class AMQPBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>,
private val amqpConfig: AMQPConfiguration,
sharedEventGroup: EventLoopGroup,
private val artemis: ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService?) {
private val bridgeMetricsService: BridgeMetricsService?,
private val bridgeConnectionTTLSeconds: Int) {
companion object {
private val log = contextLogger()
}

@@ -86,6 +112,7 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>()
try {
MDC.put("queueName", queueName)
MDC.put("source", amqpConfig.sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() })
MDC.put("maxMessageSize", amqpConfig.maxMessageSize.toString())

@@ -106,10 +133,80 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }

val amqpClient = AMQPClient(targets, legalNames, amqpConfig, sharedThreadPool = sharedEventGroup)
private val lock = ReentrantLock() // lock to serialise session level access
private var session: ClientSession? = null
private var consumer: ClientConsumer? = null
private var connectedSubscription: Subscription? = null
@Volatile
private var messagesReceived: Boolean = false
private val eventLoop: EventLoop = sharedEventGroup.next()
private var artemisState: ArtemisState = ArtemisState.STOPPED
set(value) {
logDebugWithMDC { "State change $field to $value" }
field = value
}
@Suppress("MagicNumber")
private var artemisHeartbeatPlusBackoff = TimeUnit.SECONDS.toMillis(90)
private var amqpRestartEvent: ScheduledFuture<Unit>? = null
private var scheduledExecutorService: ScheduledExecutorService
= Executors.newSingleThreadScheduledExecutor(ThreadFactoryBuilder().setNameFormat("bridge-connection-reset-%d").build())

@Suppress("ClassNaming")
private sealed class ArtemisState {
object STARTING : ArtemisState()
data class STARTED(override val pending: ScheduledFuture<Unit>) : ArtemisState()

object CHECKING : ArtemisState()
object RESTARTED : ArtemisState()
object RECEIVING : ArtemisState()

object AMQP_STOPPED : ArtemisState()
object AMQP_STARTING : ArtemisState()
object AMQP_STARTED : ArtemisState()
object AMQP_RESTARTED : ArtemisState()

object STOPPING : ArtemisState()
object STOPPED : ArtemisState()
data class STOPPED_AMQP_START_SCHEDULED(override val pending: ScheduledFuture<Unit>) : ArtemisState()

open val pending: ScheduledFuture<Unit>? = null

override fun toString(): String = javaClass.simpleName
}

private fun artemis(inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState) {
val runnable = {
synchronized(artemis) {
try {
val precedingState = artemisState
artemisState.pending?.cancel(false)
artemisState = inProgress
artemisState = block(precedingState)
} catch (ex: Exception) {
withMDC { log.error("Unexpected error in Artemis processing in state $artemisState.", ex) }
}
}
}
if (eventLoop.inEventLoop()) {
runnable()
} else {
eventLoop.execute(runnable)
}
}

private fun scheduledArtemis(delay: Long, unit: TimeUnit, inProgress: ArtemisState, block: (precedingState: ArtemisState) -> ArtemisState): ScheduledFuture<Unit> {
return eventLoop.schedule<Unit>({
artemis(inProgress, block)
}, delay, unit)
}

private fun scheduledArtemisInExecutor(delay: Long, unit: TimeUnit, inProgress: ArtemisState, nextState: ArtemisState, block: () -> Unit): ScheduledFuture<Unit> {
return scheduledExecutorService.schedule<Unit>({
artemis(inProgress) {
nextState
}
block()
}, delay, unit)
}
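// Illustrative sketch (not part of the diff above): the artemis(...) helper serialises ArtemisState
// transitions onto the bridge's Netty event loop. A transition names the in-progress state and
// returns the state to settle in, optionally inspecting the state it replaced, e.g.
//
//   artemis(ArtemisState.STARTING) { precedingState ->
//       // ... create the Artemis session and consumer ...
//       ArtemisState.STARTED(pendingCheck) // pendingCheck: a ScheduledFuture<Unit> health probe
//   }
//
// onSocketConnected() and stop() below follow exactly this pattern.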
fun start() {
logInfoWithMDC("Create new AMQP bridge")

@@ -119,55 +216,196 @@ class AMQPBridgeManager(config: MutualSslConfiguration,

fun stop() {
logInfoWithMDC("Stopping AMQP bridge")
lock.withLock {
synchronized(artemis) {
consumer?.close()
consumer = null
session?.stop()
session = null
}
}
amqpClient.stop()
connectedSubscription?.unsubscribe()
connectedSubscription = null
}

private fun onSocketConnected(connected: Boolean) {
lock.withLock {
synchronized(artemis) {
if (connected) {
logInfoWithMDC("Bridge Connected")
bridgeMetricsService?.bridgeConnected(targets, legalNames)
val sessionFactory = artemis.started!!.sessionFactory
val session = sessionFactory.createSession(NODE_P2P_USER, NODE_P2P_USER, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
this.session = session
val consumer = session.createConsumer(queueName)
this.consumer = consumer
consumer.setMessageHandler(this@AMQPBridge::clientArtemisMessageHandler)
session.start()
} else {
logInfoWithMDC("Bridge Disconnected")
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
consumer?.close()
consumer = null
session?.stop()
artemis(ArtemisState.STOPPING) {
logInfoWithMDC("Stopping Artemis because stopping AMQP bridge")
closeConsumer()
consumer = null
eventLoop.execute {
artemis(ArtemisState.STOPPING) {
stopSession()
session = null
ArtemisState.STOPPED
}
}
ArtemisState.STOPPING
}
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
connectedSubscription?.unsubscribe()
connectedSubscription = null
// Do this last because we already scheduled the Artemis stop, so it's okay to unsubscribe onConnected first.
amqpClient.stop()
}

@Suppress("ComplexMethod")
private fun onSocketConnected(connected: Boolean) {
if (connected) {
logInfoWithMDC("Bridge Connected")

bridgeMetricsService?.bridgeConnected(targets, legalNames)
if (bridgeConnectionTTLSeconds > 0) {
// AMQP outbound connection will be restarted periodically with bridgeConnectionTTLSeconds interval
amqpRestartEvent = scheduledArtemisInExecutor(bridgeConnectionTTLSeconds.toLong(), TimeUnit.SECONDS,
ArtemisState.AMQP_STOPPED, ArtemisState.AMQP_RESTARTED) {
logInfoWithMDC("Bridge connection time to live exceeded. Restarting AMQP connection")
stopAndStartOutbound(ArtemisState.AMQP_RESTARTED)
}
}
artemis(ArtemisState.STARTING) {
val startedArtemis = artemis.started
if (startedArtemis == null) {
logInfoWithMDC("Bridge Connected but Artemis is disconnected")
ArtemisState.STOPPED
} else {
logInfoWithMDC("Bridge Connected so starting Artemis")
artemisHeartbeatPlusBackoff = startedArtemis.serverLocator.connectionTTL + ARTEMIS_RETRY_BACKOFF
try {
createSessionAndConsumer(startedArtemis)
ArtemisState.STARTED(scheduledArtemis(artemisHeartbeatPlusBackoff, TimeUnit.MILLISECONDS, ArtemisState.CHECKING) {
if (!messagesReceived) {
logInfoWithMDC("No messages received on new bridge. Restarting Artemis session")
if (restartSession()) {
ArtemisState.RESTARTED
} else {
logInfoWithMDC("Artemis session restart failed. Aborting by restarting AMQP connection.")
stopAndStartOutbound()
}
} else {
ArtemisState.RECEIVING
}
})
} catch (ex: Exception) {
// Now, bounce the AMQP connection to restart the sequence of establishing the connectivity back from the beginning.
withMDC { log.warn("Create Artemis start session error. Restarting AMQP connection", ex) }
stopAndStartOutbound()
}
}
}
} else {
logInfoWithMDC("Bridge Disconnected")
amqpRestartEvent?.cancel(false)
if (artemisState != ArtemisState.AMQP_STARTING && artemisState != ArtemisState.STOPPED) {
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
}
artemis(ArtemisState.STOPPING) { precedingState: ArtemisState ->
logInfoWithMDC("Stopping Artemis because AMQP bridge disconnected")
closeConsumer()
consumer = null
eventLoop.execute {
artemis(ArtemisState.STOPPING) {
stopSession()
session = null
when (precedingState) {
ArtemisState.AMQP_STOPPED ->
ArtemisState.STOPPED_AMQP_START_SCHEDULED(scheduledArtemis(artemisHeartbeatPlusBackoff,
TimeUnit.MILLISECONDS, ArtemisState.AMQP_STARTING) { startOutbound() })
ArtemisState.AMQP_RESTARTED -> {
artemis(ArtemisState.AMQP_STARTING) { startOutbound() }
ArtemisState.AMQP_STARTING
}
else -> ArtemisState.STOPPED
}
}
}
ArtemisState.STOPPING
}
}
}

private fun startOutbound(): ArtemisState {
logInfoWithMDC("Starting AMQP client")
amqpClient.start()
return ArtemisState.AMQP_STARTED
}

private fun stopAndStartOutbound(state: ArtemisState = ArtemisState.AMQP_STOPPED): ArtemisState {
amqpClient.stop()
// Bridge disconnect will detect this state and schedule an AMQP start.
return state
}
private fun createSessionAndConsumer(startedArtemis: ArtemisMessagingClient.Started): ClientSession {
logInfoWithMDC("Creating session and consumer.")
val sessionFactory = startedArtemis.sessionFactory
val session = sessionFactory.createSession(NODE_P2P_USER, NODE_P2P_USER, false, true,
true, false, DEFAULT_ACK_BATCH_SIZE)
this.session = session
// Several producers (in the case of shared bridge) can put messages in the same outbound p2p queue.
// The consumers are created using the source x500 name as a filter
val consumer = if (amqpConfig.enableSNI) {
session.createConsumer(queueName, "hyphenated_props:sender-subject-name = '${amqpConfig.sourceX500Name}'")
} else {
session.createConsumer(queueName)
}
this.consumer = consumer
session.start()
consumer.setMessageHandler(this@AMQPBridge::clientArtemisMessageHandler)
return session
}

private fun closeConsumer(): Boolean {
var closed = false
try {
consumer?.apply {
if (!isClosed) {
close()
}
}
closed = true
} catch (ex: Exception) {
withMDC { log.warn("Close artemis consumer error", ex) }
} finally {
return closed
}
}

private fun stopSession(): Boolean {
var stopped = false
try {
session?.apply {
if (!isClosed) {
stop()
}
}
stopped = true
} catch (ex: Exception) {
withMDC { log.warn("Stop Artemis session error", ex) }
} finally {
return stopped
}
}

private fun restartSession(): Boolean {
if (!stopSession()) {
// Session timed out stopping. The request/responses can be out of sequence on the session now, so abandon it.
session = null
// The consumer is also dead now too as attached to the dead session.
consumer = null
return false
}
try {
// Does not wait for a response.
this.session?.start()
} catch (ex: Exception) {
withMDC { log.error("Start Artemis session error", ex) }
}
return true
}

private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) {
messagesReceived = true
if (artemisMessage.bodySize > amqpConfig.maxMessageSize) {
val msg = "Message exceeds maxMessageSize network parameter, maxMessageSize: [${amqpConfig.maxMessageSize}], message size: [${artemisMessage.bodySize}], " +
"dropping message, uuid: ${artemisMessage.getObjectProperty("_AMQ_DUPL_ID")}"
"dropping message, uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}"
logWarnWithMDC(msg)
bridgeMetricsService?.packetDropEvent(artemisMessage, msg)
// Ack the message to prevent same message being sent to us again.
artemisMessage.individualAcknowledge()
try {
artemisMessage.individualAcknowledge()
} catch (ex: ActiveMQObjectClosedException) {
log.warn("Artemis message was closed")
}
return
}
val data = ByteArray(artemisMessage.bodySize).apply { artemisMessage.bodyBuffer.readBytes(this) }
val properties = HashMap<String, Any?>()
for (key in P2PMessagingHeaders.whitelistedHeaders) {
if (artemisMessage.containsProperty(key)) {

@@ -178,18 +416,22 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
properties[key] = value
}
}
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty("_AMQ_DUPL_ID")}" }
logDebugWithMDC { "Bridged Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
val peerInbox = translateLocalQueueToInboxAddress(queueName)
val sendableMessage = amqpClient.createMessage(data, peerInbox,
val sendableMessage = amqpClient.createMessage(artemisMessage.payload(), peerInbox,
legalNames.first().toString(),
properties)
sendableMessage.onComplete.then {
logDebugWithMDC { "Bridge ACK ${sendableMessage.onComplete.get()}" }
lock.withLock {
eventLoop.submit {
if (sendableMessage.onComplete.get() == MessageStatus.Acknowledged) {
artemisMessage.individualAcknowledge()
try {
artemisMessage.individualAcknowledge()
} catch (ex: ActiveMQObjectClosedException) {
log.warn("Artemis message was closed")
}
} else {
logInfoWithMDC("Rollback rejected message uuid: ${artemisMessage.getObjectProperty("_AMQ_DUPL_ID")}")
logInfoWithMDC("Rollback rejected message uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}")
// We need to commit any acknowledged messages before rolling back the failed
// (unacknowledged) message.
session?.commit()

@@ -202,9 +444,9 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
} catch (ex: IllegalStateException) {
// Attempting to send a message while the AMQP client is disconnected may cause message loss.
// The failed message is rolled back after committing acknowledged messages.
lock.withLock {
ex.message?.let { logInfoWithMDC(it)}
logInfoWithMDC("Rollback rejected message uuid: ${artemisMessage.getObjectProperty("_AMQ_DUPL_ID")}")
eventLoop.submit {
ex.message?.let { logInfoWithMDC(it) }
logInfoWithMDC("Rollback rejected message uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}")
session?.commit()
session?.rollback(false)
}

@@ -213,20 +455,22 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
}
}

override fun deployBridge(queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
val newBridge = lock.withLock {
override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
lock.withLock {
val bridges = queueNamesToBridgesMap.getOrPut(queueName) { mutableListOf() }
for (target in targets) {
if (bridges.any { it.targets.contains(target) }) {
if (bridges.any { it.targets.contains(target) && it.sourceX500Name == sourceX500Name }) {
return
}
}
val newBridge = AMQPBridge(queueName, targets, legalNames, amqpConfig, sharedEventLoopGroup!!, artemis!!, bridgeMetricsService)
val newAMQPConfig = with(amqpConfig) { AMQPConfigurationImpl(keyStore, trustStore, proxyConfig, maxMessageSize,
revocationConfig, useOpenSsl, enableSNI, sourceX500Name, trace, sslHandshakeTimeout) }
val newBridge = AMQPBridge(sourceX500Name, queueName, targets, legalNames, newAMQPConfig, sharedEventLoopGroup!!, artemis!!,
bridgeMetricsService, bridgeConnectionTTLSeconds)
bridges += newBridge
bridgeMetricsService?.bridgeCreated(targets, legalNames)
newBridge
}
newBridge.start()
}.start()
}

override fun destroyBridge(queueName: String, targets: List<NetworkHostAndPort>) {

@@ -246,6 +490,17 @@ class AMQPBridgeManager(config: MutualSslConfiguration,
}
}

fun destroyAllBridges(queueName: String): Map<String, BridgeEntry> {
return lock.withLock {
// queueNamesToBridgesMap returns a mutable list, .toList converts it to an immutable list so it won't be changed by the [destroyBridge] method.
val bridges = queueNamesToBridgesMap[queueName]?.toList()
destroyBridge(queueName, bridges?.flatMap { it.targets } ?: emptyList())
bridges?.map {
it.sourceX500Name to BridgeEntry(it.queueName, it.targets, it.legalNames.toList(), serviceAddress = false)
}?.toMap() ?: emptyMap()
}
}

override fun start() {
sharedEventLoopGroup = NioEventLoopGroup(NUM_BRIDGE_THREADS)
val artemis = artemisMessageClientFactory()
@@ -1,51 +1,124 @@
@file:Suppress("TooGenericExceptionCaught") // needs to catch and handle/rethrow *all* exceptions
package net.corda.nodeapi.internal.bridging

import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX
import net.corda.nodeapi.internal.ArtemisSessionProvider
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import org.apache.activemq.artemis.api.core.ActiveMQNonExistentQueueException
import org.apache.activemq.artemis.api.core.ActiveMQQueueExistsException
import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientSession
import rx.Observable
import rx.subjects.PublishSubject
import java.util.*

class BridgeControlListener(val config: MutualSslConfiguration,
class BridgeControlListener(private val keyStore: CertificateStore,
trustStore: CertificateStore,
useOpenSSL: Boolean,
proxyConfig: ProxyConfig? = null,
maxMessageSize: Int,
crlCheckSoftFail: Boolean,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
bridgeMetricsService: BridgeMetricsService? = null) : AutoCloseable {
bridgeMetricsService: BridgeMetricsService? = null,
trace: Boolean = false,
sslHandshakeTimeout: Long? = null,
bridgeConnectionTTLSeconds: Int = 0) : AutoCloseable {
private val bridgeId: String = UUID.randomUUID().toString()
private val bridgeManager: BridgeManager = AMQPBridgeManager(
config,
maxMessageSize,
crlCheckSoftFail,
artemisMessageClientFactory,
bridgeMetricsService)
private var bridgeControlQueue = "$BRIDGE_CONTROL.$bridgeId"
private var bridgeNotifyQueue = "$BRIDGE_NOTIFY.$bridgeId"
private val validInboundQueues = mutableSetOf<String>()
private val bridgeManager = if (enableSNI) {
LoopbackBridgeManager(keyStore, trustStore, useOpenSSL, proxyConfig, maxMessageSize, revocationConfig, enableSNI,
artemisMessageClientFactory, bridgeMetricsService, this::validateReceiveTopic, trace, sslHandshakeTimeout,
bridgeConnectionTTLSeconds)
} else {
AMQPBridgeManager(keyStore, trustStore, useOpenSSL, proxyConfig, maxMessageSize, revocationConfig, enableSNI,
artemisMessageClientFactory, bridgeMetricsService, trace, sslHandshakeTimeout, bridgeConnectionTTLSeconds)
}
private var artemis: ArtemisSessionProvider? = null
private var controlConsumer: ClientConsumer? = null
private var notifyConsumer: ClientConsumer? = null

constructor(config: MutualSslConfiguration,
p2pAddress: NetworkHostAndPort,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
proxy: ProxyConfig? = null) : this(config.keyStore.get(), config.trustStore.get(), config.useOpenSsl, proxy, maxMessageSize, revocationConfig, enableSNI, { ArtemisMessagingClient(config, p2pAddress, maxMessageSize) })

companion object {
private val log = contextLogger()
}

val active: Boolean
get() = validInboundQueues.isNotEmpty()

private val _activeChange = PublishSubject.create<Boolean>().toSerialized()
val activeChange: Observable<Boolean>
get() = _activeChange

private val _failure = PublishSubject.create<BridgeControlListener>().toSerialized()
val failure: Observable<BridgeControlListener>
get() = _failure

fun start() {
stop()
bridgeManager.start()
val artemis = artemisMessageClientFactory()
this.artemis = artemis
artemis.start()
val artemisClient = artemis.started!!
val artemisSession = artemisClient.session
val bridgeControlQueue = "$BRIDGE_CONTROL.$bridgeId"
artemisSession.createTemporaryQueue(BRIDGE_CONTROL, RoutingType.MULTICAST, bridgeControlQueue)
try {
stop()

val queueDisambiguityId = UUID.randomUUID().toString()
bridgeControlQueue = "$BRIDGE_CONTROL.$queueDisambiguityId"
bridgeNotifyQueue = "$BRIDGE_NOTIFY.$queueDisambiguityId"

bridgeManager.start()
val artemis = artemisMessageClientFactory()
this.artemis = artemis
artemis.start()
val artemisClient = artemis.started!!
val artemisSession = artemisClient.session
registerBridgeControlListener(artemisSession)
registerBridgeDuplicateChecker(artemisSession)
// Attempt to read available inboxes directly from Artemis before requesting updates from connected nodes
validInboundQueues.addAll(artemisSession.addressQuery(SimpleString("$P2P_PREFIX#")).queueNames.map { it.toString() })
log.info("Found inboxes: $validInboundQueues")
if (active) {
_activeChange.onNext(true)
}
val startupMessage = BridgeControl.BridgeToNodeSnapshotRequest(bridgeId).serialize(context = SerializationDefaults.P2P_CONTEXT)
.bytes
val bridgeRequest = artemisSession.createMessage(false)
bridgeRequest.writeBodyBufferBytes(startupMessage)
artemisClient.producer.send(BRIDGE_NOTIFY, bridgeRequest)
} catch (e: Exception) {
log.error("Failure to start BridgeControlListener", e)
_failure.onNext(this)
}
}
private fun registerBridgeControlListener(artemisSession: ClientSession) {
try {
artemisSession.createTemporaryQueue(BRIDGE_CONTROL, RoutingType.MULTICAST, bridgeControlQueue)
} catch (ex: ActiveMQQueueExistsException) {
// Ignore if there is a queue still not cleaned up
}

val control = artemisSession.createConsumer(bridgeControlQueue)
controlConsumer = control
control.setMessageHandler { msg ->

@@ -53,22 +126,64 @@ class BridgeControlListener(val config: MutualSslConfiguration,
processControlMessage(msg)
} catch (ex: Exception) {
log.error("Unable to process bridge control message", ex)
_failure.onNext(this)
}
msg.acknowledge()
}
}

private fun registerBridgeDuplicateChecker(artemisSession: ClientSession) {
try {
artemisSession.createTemporaryQueue(BRIDGE_NOTIFY, RoutingType.MULTICAST, bridgeNotifyQueue)
} catch (ex: ActiveMQQueueExistsException) {
// Ignore if there is a queue still not cleaned up
}
val notify = artemisSession.createConsumer(bridgeNotifyQueue)
notifyConsumer = notify
notify.setMessageHandler { msg ->
try {
val data: ByteArray = ByteArray(msg.bodySize).apply { msg.bodyBuffer.readBytes(this) }
val notifyMessage = data.deserialize<BridgeControl.BridgeToNodeSnapshotRequest>(context = SerializationDefaults.P2P_CONTEXT)
if (notifyMessage.bridgeIdentity != bridgeId) {
log.error("Fatal Error! Two bridges have been configured simultaneously! Check the enterpriseConfiguration.externalBridge status")
System.exit(1)
}
} catch (ex: Exception) {
log.error("Unable to process bridge notification message", ex)
_failure.onNext(this)
}
msg.acknowledge()
}
val startupMessage = BridgeControl.BridgeToNodeSnapshotRequest(bridgeId).serialize(context = SerializationDefaults.P2P_CONTEXT).bytes
val bridgeRequest = artemisSession.createMessage(false)
bridgeRequest.writeBodyBufferBytes(startupMessage)
artemisClient.producer.send(BRIDGE_NOTIFY, bridgeRequest)
}

fun stop() {
validInboundQueues.clear()
controlConsumer?.close()
controlConsumer = null
artemis?.stop()
artemis = null
bridgeManager.stop()
try {
if (active) {
_activeChange.onNext(false)
}
validInboundQueues.clear()
controlConsumer?.close()
controlConsumer = null
notifyConsumer?.close()
notifyConsumer = null
artemis?.apply {
try {
started?.session?.deleteQueue(bridgeControlQueue)
} catch (e: ActiveMQNonExistentQueueException) {
log.warn("Queue $bridgeControlQueue does not exist and it can't be deleted")
}
try {
started?.session?.deleteQueue(bridgeNotifyQueue)
} catch (e: ActiveMQNonExistentQueueException) {
log.warn("Queue $bridgeNotifyQueue does not exist and it can't be deleted")
}
stop()
}
artemis = null
bridgeManager.stop()
} catch (e: Exception) {
log.error("Failure to stop BridgeControlListener", e)
}
}

override fun close() = stop()

@@ -91,6 +206,10 @@ class BridgeControlListener(val config: MutualSslConfiguration,
log.info("Received bridge control message $controlMessage")
when (controlMessage) {
is BridgeControl.NodeToBridgeSnapshot -> {
if (!isConfigured(controlMessage.nodeIdentity)) {
log.error("Fatal error! Bridge not configured with keystore for node with legal name ${controlMessage.nodeIdentity}.")
System.exit(1)
}
if (!controlMessage.inboxQueues.all { validateInboxQueueName(it) }) {
log.error("Invalid queue names in control message $controlMessage")
return

@@ -99,10 +218,20 @@ class BridgeControlListener(val config: MutualSslConfiguration,
log.error("Invalid queue names in control message $controlMessage")
return
}
for (outQueue in controlMessage.sendQueues) {
bridgeManager.deployBridge(outQueue.queueName, outQueue.targets, outQueue.legalNames.toSet())
}

val wasActive = active
validInboundQueues.addAll(controlMessage.inboxQueues)
for (outQueue in controlMessage.sendQueues) {
bridgeManager.deployBridge(controlMessage.nodeIdentity, outQueue.queueName, outQueue.targets, outQueue.legalNames.toSet())
}
log.info("Added inbox: ${controlMessage.inboxQueues}. Current inboxes: $validInboundQueues.")
if (bridgeManager is LoopbackBridgeManager) {
// Notify the loopback bridge manager that inboxes have changed.
bridgeManager.inboxesAdded(controlMessage.inboxQueues)
}
if (!wasActive && active) {
_activeChange.onNext(true)
}
}
is BridgeControl.BridgeToNodeSnapshotRequest -> {
log.error("Message from Bridge $controlMessage detected on wrong topic!")

@@ -112,7 +241,7 @@ class BridgeControlListener(val config: MutualSslConfiguration,
log.error("Invalid queue names in control message $controlMessage")
return
}
bridgeManager.deployBridge(controlMessage.bridgeInfo.queueName, controlMessage.bridgeInfo.targets, controlMessage.bridgeInfo.legalNames.toSet())
bridgeManager.deployBridge(controlMessage.nodeIdentity, controlMessage.bridgeInfo.queueName, controlMessage.bridgeInfo.targets, controlMessage.bridgeInfo.legalNames.toSet())
}
is BridgeControl.Delete -> {
if (!controlMessage.bridgeInfo.queueName.startsWith(PEERS_PREFIX)) {

@@ -121,7 +250,19 @@ class BridgeControlListener(val config: MutualSslConfiguration,
}
bridgeManager.destroyBridge(controlMessage.bridgeInfo.queueName, controlMessage.bridgeInfo.targets)
}
is BridgeControl.BridgeHealthCheck -> {
log.warn("Not currently doing anything on BridgeHealthCheck")
return
}
}
}

private fun isConfigured(sourceX500Name: String): Boolean {
val keyStore = keyStore.value.internal
return keyStore.aliases().toList().any { alias ->
val x500Name = keyStore.getCertificate(alias).x509.subjectX500Principal
val cordaX500Name = CordaX500Name.build(x500Name)
cordaX500Name.toString() == sourceX500Name
}
}
}
@@ -11,7 +11,7 @@ import net.corda.core.utilities.NetworkHostAndPort
* @property legalNames The list of acceptable [CordaX500Name] names that should be presented as subject of the validated peer TLS certificate.
*/
@CordaSerializable
data class BridgeEntry(val queueName: String, val targets: List<NetworkHostAndPort>, val legalNames: List<CordaX500Name>)
data class BridgeEntry(val queueName: String, val targets: List<NetworkHostAndPort>, val legalNames: List<CordaX500Name>, val serviceAddress: Boolean)

sealed class BridgeControl {
/**

@@ -47,4 +47,13 @@ sealed class BridgeControl {
*/
@CordaSerializable
data class Delete(val nodeIdentity: String, val bridgeInfo: BridgeEntry) : BridgeControl()

/**
* This message is sent to the Bridge to check its health.
* @property requestId The identifier for the health check request, as health checks are likely to be produced repeatedly.
* @property command Allows specifying the sort of health check that needs to be performed.
* @property bridgeInfo The connection details of the new bridge (optional).
*/
@CordaSerializable
data class BridgeHealthCheck(val requestId: Long, val command: String, val bridgeInfo: BridgeEntry?) : BridgeControl()
}
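// Illustrative sketch (not part of the diff above): constructing the new health check message.
// The command string and request id below are hypothetical; the control listener currently only
// logs "Not currently doing anything on BridgeHealthCheck" when it receives one.
val healthCheck = BridgeControl.BridgeHealthCheck(requestId = 1L, command = "ping", bridgeInfo = null)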
@@ -3,17 +3,20 @@ package net.corda.nodeapi.internal.bridging
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.VisibleForTesting
import net.corda.core.utilities.NetworkHostAndPort
import org.apache.activemq.artemis.api.core.client.ClientMessage

/**
* Provides an internal interface that the [BridgeControlListener] delegates to for Bridge activities.
*/
@VisibleForTesting
interface BridgeManager : AutoCloseable {
fun deployBridge(queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>)
fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>)

fun destroyBridge(queueName: String, targets: List<NetworkHostAndPort>)

fun start()

fun stop()
}
}

fun ClientMessage.payload() = ByteArray(bodySize).apply { bodyBuffer.readBytes(this) }
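// Illustrative sketch (not part of the diff above): ClientMessage.payload() replaces the manual
// copy pattern used by the bridge message handlers, i.e.
//
//   val data = ByteArray(artemisMessage.bodySize).apply { artemisMessage.bodyBuffer.readBytes(this) }
//
// becomes
//
//   val data = artemisMessage.payload()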
@@ -0,0 +1,223 @@
package net.corda.nodeapi.internal.bridging

import net.corda.nodeapi.internal.ConcurrentBox
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.VisibleForTesting
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.ArtemisMessagingComponent
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_P2P_USER
import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateInboxAddressToLocalQueue
import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress.Companion.translateLocalQueueToInboxAddress
import net.corda.nodeapi.internal.ArtemisSessionProvider
import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.ProxyConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.stillOpen
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.slf4j.MDC

/**
* The LoopbackBridgeManager holds the list of independent LoopbackBridge objects that actively loopback messages to local Artemis
* inboxes.
*/
@VisibleForTesting
class LoopbackBridgeManager(keyStore: CertificateStore,
trustStore: CertificateStore,
useOpenSSL: Boolean,
proxyConfig: ProxyConfig? = null,
maxMessageSize: Int,
revocationConfig: RevocationConfig,
enableSNI: Boolean,
private val artemisMessageClientFactory: () -> ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService? = null,
private val isLocalInbox: (String) -> Boolean,
trace: Boolean,
sslHandshakeTimeout: Long? = null,
bridgeConnectionTTLSeconds: Int = 0) : AMQPBridgeManager(keyStore, trustStore, useOpenSSL, proxyConfig,
maxMessageSize, revocationConfig, enableSNI,
artemisMessageClientFactory, bridgeMetricsService,
trace, sslHandshakeTimeout,
bridgeConnectionTTLSeconds) {

companion object {
private val log = contextLogger()
}

private val queueNamesToBridgesMap = ConcurrentBox(mutableMapOf<String, MutableList<LoopbackBridge>>())
private var artemis: ArtemisSessionProvider? = null

/**
* Each LoopbackBridge is an independent consumer of messages from the Artemis local queue per designated endpoint.
* It attempts to loopback these messages via ArtemisClient to the local inbox.
*/
private class LoopbackBridge(val sourceX500Name: String,
val queueName: String,
val targets: List<NetworkHostAndPort>,
val legalNames: Set<CordaX500Name>,
artemis: ArtemisSessionProvider,
private val bridgeMetricsService: BridgeMetricsService?) {
companion object {
private val log = contextLogger()
}

// TODO: refactor MDC support, duplicated in AMQPBridgeManager.
private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap()
try {
MDC.put("queueName", queueName)
MDC.put("source", sourceX500Name)
MDC.put("targets", targets.joinToString(separator = ";") { it.toString() })
MDC.put("legalNames", legalNames.joinToString(separator = ";") { it.toString() })
MDC.put("bridgeType", "loopback")
block()
} finally {
MDC.setContextMap(oldMDC)
}
}

private fun logDebugWithMDC(msg: () -> String) {
if (log.isDebugEnabled) {
withMDC { log.debug(msg()) }
}
}

private fun logInfoWithMDC(msg: String) = withMDC { log.info(msg) }

private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }

private val artemis = ConcurrentBox(artemis)
private var consumerSession: ClientSession? = null
|
||||
private var producerSession: ClientSession? = null
|
||||
private var consumer: ClientConsumer? = null
|
||||
private var producer: ClientProducer? = null
|
||||
|
||||
fun start() {
|
||||
logInfoWithMDC("Create new Artemis loopback bridge")
|
||||
artemis.exclusive {
|
||||
logInfoWithMDC("Bridge Connected")
|
||||
bridgeMetricsService?.bridgeConnected(targets, legalNames)
|
||||
val sessionFactory = started!!.sessionFactory
|
||||
this@LoopbackBridge.consumerSession = sessionFactory.createSession(NODE_P2P_USER, NODE_P2P_USER, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
|
||||
this@LoopbackBridge.producerSession = sessionFactory.createSession(NODE_P2P_USER, NODE_P2P_USER, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
|
||||
// Several producers (in the case of a shared bridge) can put messages onto the same outbound p2p queue. The consumers are created using the source X500 name as a filter.
|
||||
val consumer = consumerSession!!.createConsumer(queueName, "hyphenated_props:sender-subject-name = '$sourceX500Name'")
|
||||
consumer.setMessageHandler(this@LoopbackBridge::clientArtemisMessageHandler)
|
||||
this@LoopbackBridge.consumer = consumer
|
||||
this@LoopbackBridge.producer = producerSession!!.createProducer()
|
||||
consumerSession?.start()
|
||||
producerSession?.start()
|
||||
}
|
||||
}
|
||||
|
||||
fun stop() {
|
||||
logInfoWithMDC("Stopping AMQP bridge")
|
||||
artemis.exclusive {
|
||||
bridgeMetricsService?.bridgeDisconnected(targets, legalNames)
|
||||
consumer?.apply { if (!isClosed) close() }
|
||||
consumer = null
|
||||
producer?.apply { if (!isClosed) close() }
|
||||
producer = null
|
||||
consumerSession?.apply { if (stillOpen()) stop() }
|
||||
consumerSession = null
|
||||
producerSession?.apply { if (stillOpen()) stop() }
|
||||
producerSession = null
|
||||
}
|
||||
}
|
||||
|
||||
private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) {
|
||||
logDebugWithMDC { "Loopback Send to ${legalNames.first()} uuid: ${artemisMessage.getObjectProperty(MESSAGE_ID_KEY)}" }
|
||||
val peerInbox = translateLocalQueueToInboxAddress(queueName)
|
||||
producer?.send(SimpleString(peerInbox), artemisMessage) { artemisMessage.individualAcknowledge() }
|
||||
bridgeMetricsService?.let { metricsService ->
|
||||
val properties = ArtemisMessagingComponent.Companion.P2PMessagingHeaders.whitelistedHeaders.mapNotNull { key ->
|
||||
if (artemisMessage.containsProperty(key)) {
|
||||
key to artemisMessage.getObjectProperty(key).let { (it as? SimpleString)?.toString() ?: it }
|
||||
} else {
|
||||
null
|
||||
}
|
||||
}.toMap()
|
||||
metricsService.packetAcceptedEvent(SendableMessageImpl(artemisMessage.payload(), peerInbox, legalNames.first().toString(), targets.first(), properties))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun deployBridge(sourceX500Name: String, queueName: String, targets: List<NetworkHostAndPort>, legalNames: Set<CordaX500Name>) {
|
||||
val inboxAddress = translateLocalQueueToInboxAddress(queueName)
|
||||
if (isLocalInbox(inboxAddress)) {
|
||||
log.info("Deploying loopback bridge for $queueName, source $sourceX500Name")
|
||||
queueNamesToBridgesMap.exclusive {
|
||||
val bridges = getOrPut(queueName) { mutableListOf() }
|
||||
for (target in targets) {
|
||||
if (bridges.any { it.targets.contains(target) && it.sourceX500Name == sourceX500Name }) {
|
||||
return
|
||||
}
|
||||
}
|
||||
val newBridge = LoopbackBridge(sourceX500Name, queueName, targets, legalNames, artemis!!, bridgeMetricsService)
|
||||
bridges += newBridge
|
||||
bridgeMetricsService?.bridgeCreated(targets, legalNames)
|
||||
newBridge
|
||||
}.start()
|
||||
} else {
|
||||
log.info("Deploying AMQP bridge for $queueName, source $sourceX500Name")
|
||||
super.deployBridge(sourceX500Name, queueName, targets, legalNames)
|
||||
}
|
||||
}
|
||||
|
||||
override fun destroyBridge(queueName: String, targets: List<NetworkHostAndPort>) {
|
||||
super.destroyBridge(queueName, targets)
|
||||
queueNamesToBridgesMap.exclusive {
|
||||
val bridges = this[queueName] ?: mutableListOf()
|
||||
for (target in targets) {
|
||||
val bridge = bridges.firstOrNull { it.targets.contains(target) }
|
||||
if (bridge != null) {
|
||||
bridges -= bridge
|
||||
if (bridges.isEmpty()) {
|
||||
remove(queueName)
|
||||
}
|
||||
bridge.stop()
|
||||
bridgeMetricsService?.bridgeDestroyed(bridge.targets, bridge.legalNames)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove any AMQP bridge for the local inbox and create a loopback bridge for that queue.
|
||||
*/
|
||||
fun inboxesAdded(inboxes: List<String>) {
|
||||
for (inbox in inboxes) {
|
||||
super.destroyAllBridges(translateInboxAddressToLocalQueue(inbox)).forEach { source, bridgeEntry ->
|
||||
log.info("Destroyed AMQP Bridge '${bridgeEntry.queueName}', creating Loopback bridge for local inbox.")
|
||||
deployBridge(source, bridgeEntry.queueName, bridgeEntry.targets, bridgeEntry.legalNames.toSet())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun start() {
|
||||
super.start()
|
||||
val artemis = artemisMessageClientFactory()
|
||||
this.artemis = artemis
|
||||
artemis.start()
|
||||
}
|
||||
|
||||
override fun stop() = close()
|
||||
|
||||
override fun close() {
|
||||
super.close()
|
||||
queueNamesToBridgesMap.exclusive {
|
||||
for (bridge in values.flatten()) {
|
||||
bridge.stop()
|
||||
}
|
||||
clear()
|
||||
artemis?.stop()
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,63 @@
|
||||
package net.corda.nodeapi.internal.config
|
||||
|
||||
import net.corda.core.utilities.minutes
|
||||
import net.corda.core.utilities.seconds
|
||||
import java.time.Duration
|
||||
|
||||
/**
|
||||
* Predefined connection configurations used by Artemis clients (currently used in the P2P messaging layer).
|
||||
* The enum names represent the approximate total duration of the failover (with exponential back-off). The formula used to calculate
|
||||
* this duration is as follows:
|
||||
*
|
||||
* totalFailoverDuration = SUM(k=0 to [reconnectAttempts] - 1) of [retryInterval] * POW([retryIntervalMultiplier], k)
|
||||
*
|
||||
* Example calculation for [DEFAULT]:
|
||||
*
|
||||
* totalFailoverDuration = 5 + 5 * 1.5 + 5 * (1.5)^2 + 5 * (1.5)^3 + 5 * (1.5)^4 = ~66 seconds
|
||||
*
|
||||
* @param failoverOnInitialAttempt Determines whether failover is triggered if initial connection fails.
|
||||
* @param initialConnectAttempts The number of reconnect attempts if failover is enabled for initial connection. A value
|
||||
* of -1 represents infinite attempts.
|
||||
* @param reconnectAttempts The number of reconnect attempts for failover after initial connection is done. A value
|
||||
* of -1 represents infinite attempts.
|
||||
* @param retryInterval Duration between reconnect attempts.
|
||||
* @param retryIntervalMultiplier Value used in the reconnection back-off process.
|
||||
* @param maxRetryInterval Determines the maximum duration between reconnection attempts. Useful when using infinite retries.
|
||||
*/
|
||||
enum class MessagingServerConnectionConfiguration {
|
||||
|
||||
DEFAULT {
|
||||
override fun failoverOnInitialAttempt(isHa: Boolean) = true
|
||||
override fun initialConnectAttempts(isHa: Boolean) = 5
|
||||
override fun reconnectAttempts(isHa: Boolean) = 5
|
||||
override fun retryInterval() = 5.seconds
|
||||
override fun retryIntervalMultiplier() = 1.5
|
||||
override fun maxRetryInterval(isHa: Boolean) = 3.minutes
|
||||
},
|
||||
|
||||
FAIL_FAST {
|
||||
override fun failoverOnInitialAttempt(isHa: Boolean) = isHa
|
||||
override fun initialConnectAttempts(isHa: Boolean) = 0
|
||||
// Clients die too quickly during failover/failback; a few reconnect attempts are needed to allow the new master to become active
|
||||
override fun reconnectAttempts(isHa: Boolean) = if (isHa) 3 else 0
|
||||
override fun retryInterval() = 5.seconds
|
||||
override fun retryIntervalMultiplier() = 1.5
|
||||
override fun maxRetryInterval(isHa: Boolean) = 3.minutes
|
||||
},
|
||||
|
||||
CONTINUOUS_RETRY {
|
||||
override fun failoverOnInitialAttempt(isHa: Boolean) = true
|
||||
override fun initialConnectAttempts(isHa: Boolean) = if (isHa) 0 else -1
|
||||
override fun reconnectAttempts(isHa: Boolean) = -1
|
||||
override fun retryInterval() = 5.seconds
|
||||
override fun retryIntervalMultiplier() = 1.5
|
||||
override fun maxRetryInterval(isHa: Boolean) = if (isHa) 3.minutes else 5.minutes
|
||||
};
|
||||
|
||||
abstract fun failoverOnInitialAttempt(isHa: Boolean): Boolean
|
||||
abstract fun initialConnectAttempts(isHa: Boolean): Int
|
||||
abstract fun reconnectAttempts(isHa: Boolean): Int
|
||||
abstract fun retryInterval(): Duration
|
||||
abstract fun retryIntervalMultiplier(): Double
|
||||
abstract fun maxRetryInterval(isHa: Boolean): Duration
|
||||
}
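To make the back-off arithmetic above concrete, here is a small helper sketch (not part of the commit) that reproduces the ~66 second figure quoted for DEFAULT. It sums reconnectAttempts terms starting at k = 0 and ignores the maxRetryInterval cap, which none of the DEFAULT terms reach.

// Sketch only: approximate total failover duration for a mode, assuming the cap is never hit.
fun approximateFailoverDuration(mode: MessagingServerConnectionConfiguration, isHa: Boolean): Duration {
    val base = mode.retryInterval().seconds.toDouble()
    var total = 0.0
    for (k in 0 until mode.reconnectAttempts(isHa)) {
        total += base * Math.pow(mode.retryIntervalMultiplier(), k.toDouble())
    }
    return Duration.ofMillis((total * 1000).toLong())
}
// approximateFailoverDuration(MessagingServerConnectionConfiguration.DEFAULT, isHa = false)
// => 5 + 7.5 + 11.25 + 16.875 + 25.3125 seconds, roughly 66 seconds in total.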
|
@ -143,4 +143,4 @@ enum class WrappingMode {
|
||||
WRAPPED
|
||||
}
|
||||
|
||||
class WrappedPrivateKey(val keyMaterial: ByteArray, val signatureScheme: SignatureScheme)
|
||||
class WrappedPrivateKey(val keyMaterial: ByteArray, val signatureScheme: SignatureScheme, val encodingVersion: Int? = null)
|
@ -162,20 +162,30 @@ class BCCryptoService(private val legalName: X500Principal,
|
||||
wrappingKeyStore.save(wrappingKeyStorePath!!, certificateStore.password)
|
||||
}
|
||||
|
||||
/**
|
||||
* Using "AESWRAPPAD" cipher spec for key wrapping defined by [RFC 5649](https://tools.ietf.org/html/rfc5649).
|
||||
* "AESWRAPPAD" (same as "AESKWP" or "AESRFC5649WRAP") is implemented in [org.bouncycastle.jcajce.provider.symmetric.AES.WrapPad] using
|
||||
* [org.bouncycastle.crypto.engines.RFC5649WrapEngine]. See:
|
||||
* - https://www.bouncycastle.org/docs/docs1.5on/org/bouncycastle/crypto/engines/AESWrapPadEngine.html
|
||||
* - https://www.bouncycastle.org/docs/docs1.5on/org/bouncycastle/crypto/engines/RFC5649WrapEngine.html
|
||||
*
|
||||
* Keys encoded with "AESWRAPPAD" are stored with encodingVersion = 1. Previously used cipher spec ("AES" == "AES/ECB/PKCS5Padding")
|
||||
* corresponds to encodingVersion = null.
|
||||
*/
|
||||
override fun generateWrappedKeyPair(masterKeyAlias: String, childKeyScheme: SignatureScheme): Pair<PublicKey, WrappedPrivateKey> {
|
||||
if (!wrappingKeyStore.containsAlias(masterKeyAlias)) {
|
||||
throw IllegalStateException("There is no master key under the alias: $masterKeyAlias")
|
||||
}
|
||||
|
||||
val wrappingKey = wrappingKeyStore.getKey(masterKeyAlias, certificateStore.entryPassword.toCharArray())
|
||||
val cipher = Cipher.getInstance("AES", cordaBouncyCastleProvider)
|
||||
val cipher = Cipher.getInstance("AESWRAPPAD", cordaBouncyCastleProvider)
|
||||
cipher.init(Cipher.WRAP_MODE, wrappingKey)
|
||||
|
||||
val keyPairGenerator = keyPairGeneratorFromScheme(childKeyScheme)
|
||||
val keyPair = keyPairGenerator.generateKeyPair()
|
||||
val privateKeyMaterialWrapped = cipher.wrap(keyPair.private)
|
||||
|
||||
return Pair(keyPair.public, WrappedPrivateKey(privateKeyMaterialWrapped, childKeyScheme))
|
||||
return Pair(keyPair.public, WrappedPrivateKey(privateKeyMaterialWrapped, childKeyScheme, encodingVersion = 1))
|
||||
}
|
||||
|
||||
override fun sign(masterKeyAlias: String, wrappedPrivateKey: WrappedPrivateKey, payloadToSign: ByteArray): ByteArray {
|
||||
@ -184,7 +194,12 @@ class BCCryptoService(private val legalName: X500Principal,
|
||||
}
|
||||
|
||||
val wrappingKey = wrappingKeyStore.getKey(masterKeyAlias, certificateStore.entryPassword.toCharArray())
|
||||
val cipher = Cipher.getInstance("AES", cordaBouncyCastleProvider)
|
||||
// Keeping backwards compatibility with previous encoding algorithms
|
||||
val algorithm = when(wrappedPrivateKey.encodingVersion) {
|
||||
1 -> "AESWRAPPAD"
|
||||
else -> "AES"
|
||||
}
|
||||
val cipher = Cipher.getInstance(algorithm, cordaBouncyCastleProvider)
|
||||
cipher.init(Cipher.UNWRAP_MODE, wrappingKey)
|
||||
|
||||
val privateKey = cipher.unwrap(wrappedPrivateKey.keyMaterial, keyAlgorithmFromScheme(wrappedPrivateKey.signatureScheme), Cipher.PRIVATE_KEY) as PrivateKey
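For readers unfamiliar with the JCE key-wrapping API used above, a self-contained sketch of the RFC 5649 round trip follows. It talks to Bouncy Castle directly rather than through BCCryptoService, and the EC curve size and provider wiring are illustrative assumptions.

import java.security.KeyPairGenerator
import java.security.PrivateKey
import javax.crypto.Cipher
import javax.crypto.KeyGenerator
import org.bouncycastle.jce.provider.BouncyCastleProvider

fun rfc5649RoundTrip() {
    val provider = BouncyCastleProvider()
    // The wrapping (master) key is a plain AES key; BCCryptoService keeps it in the wrapping key store.
    val wrappingKey = KeyGenerator.getInstance("AES", provider).apply { init(256) }.generateKey()
    val childKeyPair = KeyPairGenerator.getInstance("EC", provider).apply { initialize(256) }.generateKeyPair()

    val wrap = Cipher.getInstance("AESWRAPPAD", provider)
    wrap.init(Cipher.WRAP_MODE, wrappingKey)
    val keyMaterial = wrap.wrap(childKeyPair.private)   // stored as WrappedPrivateKey(keyMaterial, scheme, encodingVersion = 1)

    val unwrap = Cipher.getInstance("AESWRAPPAD", provider)
    unwrap.init(Cipher.UNWRAP_MODE, wrappingKey)
    val recovered = unwrap.unwrap(keyMaterial, "EC", Cipher.PRIVATE_KEY) as PrivateKey
    // recovered is the child private key reconstructed from the wrapped material.
}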
|
||||
|
@ -98,11 +98,12 @@ class NodeLifecycleEventsDistributor : Closeable {
|
||||
orderedSnapshot.forEach {
|
||||
log.debug("Distributing event $event to: $it")
|
||||
val updateResult = it.update(event)
|
||||
if (updateResult.isSuccess) {
|
||||
log.debug("Event $event distribution outcome: $updateResult")
|
||||
} else {
|
||||
log.error("Failed to distribute event $event, failure outcome: $updateResult")
|
||||
handlePossibleFatalTermination(event, updateResult as Try.Failure<String>)
|
||||
when(updateResult) {
|
||||
is Try.Success -> log.debug("Event $event distribution outcome: $updateResult")
|
||||
is Try.Failure -> {
|
||||
log.error("Failed to distribute event $event, failure outcome: $updateResult", updateResult.exception)
|
||||
handlePossibleFatalTermination(event, updateResult)
|
||||
}
|
||||
}
|
||||
}
|
||||
result.set(null)
|
||||
|
@ -2,9 +2,9 @@ package net.corda.nodeapi.internal.network
|
||||
|
||||
import com.typesafe.config.Config
|
||||
import net.corda.common.configuration.parsing.internal.Configuration
|
||||
import net.corda.common.configuration.parsing.internal.get
|
||||
import net.corda.common.configuration.parsing.internal.mapValid
|
||||
import net.corda.common.configuration.parsing.internal.nested
|
||||
import net.corda.common.configuration.parsing.internal.withOptions
|
||||
import net.corda.common.validation.internal.Validated
|
||||
import net.corda.core.internal.noPackageOverlap
|
||||
import net.corda.core.internal.requirePackageValid
|
||||
@ -17,7 +17,7 @@ import java.security.KeyStoreException
|
||||
|
||||
typealias Valid<TARGET> = Validated<TARGET, Configuration.Validation.Error>
|
||||
|
||||
fun Config.parseAsNetworkParametersConfiguration(options: Configuration.Validation.Options = Configuration.Validation.Options(strict = false)):
|
||||
fun Config.parseAsNetworkParametersConfiguration(options: Configuration.Options = Configuration.Options.defaults):
|
||||
Valid<NetworkParametersOverrides> = NetworkParameterOverridesSpec.parse(this, options)
|
||||
|
||||
internal fun <T> badValue(msg: String): Valid<T> = Validated.invalid(sequenceOf(Configuration.Validation.Error.BadValue.of(msg)).toSet())
|
||||
@ -36,11 +36,12 @@ internal object NetworkParameterOverridesSpec : Configuration.Specification<Netw
|
||||
private val keystorePassword by string()
|
||||
private val keystoreAlias by string()
|
||||
|
||||
override fun parseValid(configuration: Config): Validated<PackageOwner, Configuration.Validation.Error> {
|
||||
val suppliedKeystorePath = configuration[keystore]
|
||||
val keystorePassword = configuration[keystorePassword]
|
||||
override fun parseValid(configuration: Config, options: Configuration.Options): Validated<PackageOwner, Configuration.Validation.Error> {
|
||||
val config = configuration.withOptions(options)
|
||||
val suppliedKeystorePath = config[keystore]
|
||||
val keystorePassword = config[keystorePassword]
|
||||
return try {
|
||||
val javaPackageName = configuration[packageName]
|
||||
val javaPackageName = config[packageName]
|
||||
val absoluteKeystorePath = if (suppliedKeystorePath.isAbsolute) {
|
||||
suppliedKeystorePath
|
||||
} else {
|
||||
@ -49,10 +50,10 @@ internal object NetworkParameterOverridesSpec : Configuration.Specification<Netw
|
||||
}.toAbsolutePath()
|
||||
val ks = loadKeyStore(absoluteKeystorePath, keystorePassword)
|
||||
return try {
|
||||
val publicKey = ks.getCertificate(configuration[keystoreAlias]).publicKey
|
||||
val publicKey = ks.getCertificate(config[keystoreAlias]).publicKey
|
||||
valid(PackageOwner(javaPackageName, publicKey))
|
||||
} catch (kse: KeyStoreException) {
|
||||
badValue("Keystore has not been initialized for alias ${configuration[keystoreAlias]}.")
|
||||
badValue("Keystore has not been initialized for alias ${config[keystoreAlias]}")
|
||||
}
|
||||
} catch (kse: KeyStoreException) {
|
||||
badValue("Password is incorrect or the key store is damaged for keyStoreFilePath: $suppliedKeystorePath.")
|
||||
@ -79,8 +80,9 @@ internal object NetworkParameterOverridesSpec : Configuration.Specification<Netw
|
||||
}
|
||||
}
|
||||
|
||||
override fun parseValid(configuration: Config): Valid<NetworkParametersOverrides> {
|
||||
val packageOwnership = configuration[packageOwnership]
|
||||
override fun parseValid(configuration: Config, options: Configuration.Options): Valid<NetworkParametersOverrides> {
|
||||
val config = configuration.withOptions(options)
|
||||
val packageOwnership = config[packageOwnership]
|
||||
if (packageOwnership != null && !noPackageOverlap(packageOwnership.map { it.javaPackageName })) {
|
||||
return Validated.invalid(sequenceOf(Configuration.Validation.Error.BadValue.of(
|
||||
"Package namespaces must not overlap",
|
||||
@ -89,11 +91,11 @@ internal object NetworkParameterOverridesSpec : Configuration.Specification<Netw
|
||||
)).toSet())
|
||||
}
|
||||
return valid(NetworkParametersOverrides(
|
||||
minimumPlatformVersion = configuration[minimumPlatformVersion],
|
||||
maxMessageSize = configuration[maxMessageSize],
|
||||
maxTransactionSize = configuration[maxTransactionSize],
|
||||
minimumPlatformVersion = config[minimumPlatformVersion],
|
||||
maxMessageSize = config[maxMessageSize],
|
||||
maxTransactionSize = config[maxTransactionSize],
|
||||
packageOwnership = packageOwnership,
|
||||
eventHorizon = configuration[eventHorizon]
|
||||
eventHorizon = config[eventHorizon]
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -214,14 +214,15 @@ class CordaPersistence(
|
||||
* @param isolationLevel isolation level for the transaction.
|
||||
* @param statement to be executed in the scope of this transaction.
|
||||
*/
|
||||
fun <T> transaction(isolationLevel: TransactionIsolationLevel, statement: DatabaseTransaction.() -> T): T =
|
||||
transaction(isolationLevel, 2, false, statement)
|
||||
fun <T> transaction(isolationLevel: TransactionIsolationLevel, useErrorHandler: Boolean, statement: DatabaseTransaction.() -> T): T =
|
||||
transaction(isolationLevel, 2, false, useErrorHandler, statement)
|
||||
|
||||
/**
|
||||
* Executes given statement in the scope of transaction with the transaction level specified at the creation time.
|
||||
* @param statement to be executed in the scope of this transaction.
|
||||
*/
|
||||
fun <T> transaction(statement: DatabaseTransaction.() -> T): T = transaction(defaultIsolationLevel, statement)
|
||||
@JvmOverloads
|
||||
fun <T> transaction(useErrorHandler: Boolean = true, statement: DatabaseTransaction.() -> T): T = transaction(defaultIsolationLevel, useErrorHandler, statement)
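A hedged usage sketch of the new flag (the entity name is hypothetical): by default, nested transactions keep the error handler, which routes SQL, persistence and HospitalizeFlowException failures through errorHandler before rethrowing; internal callers that manage failures themselves can now opt out.

// database is a CordaPersistence instance; MyEntity is a hypothetical mapped class.
database.transaction {
    session.save(MyEntity(id = 1))                 // failures go through the registered error handler
}
database.transaction(useErrorHandler = false) {
    session.save(MyEntity(id = 2))                 // failures propagate without touching the error handler
}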
|
||||
|
||||
/**
|
||||
* Executes given statement in the scope of transaction, with the given isolation level.
|
||||
@ -231,7 +232,7 @@ class CordaPersistence(
|
||||
* @param statement to be executed in the scope of this transaction.
|
||||
*/
|
||||
fun <T> transaction(isolationLevel: TransactionIsolationLevel, recoverableFailureTolerance: Int,
|
||||
recoverAnyNestedSQLException: Boolean, statement: DatabaseTransaction.() -> T): T {
|
||||
recoverAnyNestedSQLException: Boolean, useErrorHandler: Boolean, statement: DatabaseTransaction.() -> T): T {
|
||||
_contextDatabase.set(this)
|
||||
val outer = contextTransactionOrNull
|
||||
return if (outer != null) {
|
||||
@ -240,26 +241,34 @@ class CordaPersistence(
|
||||
// previously been created by the flow state machine in ActionExecutorImpl#executeCreateTransaction
|
||||
// b. exceptions coming out from top level transactions are already being handled in CordaPersistence#inTopLevelTransaction
|
||||
// i.e. roll back and close the transaction
|
||||
try {
|
||||
if(useErrorHandler) {
|
||||
outer.withErrorHandler(statement)
|
||||
} else {
|
||||
outer.statement()
|
||||
} catch (e: Exception) {
|
||||
if (e is SQLException || e is PersistenceException || e is HospitalizeFlowException) {
|
||||
outer.errorHandler(e)
|
||||
}
|
||||
throw e
|
||||
}
|
||||
} else {
|
||||
inTopLevelTransaction(isolationLevel, recoverableFailureTolerance, recoverAnyNestedSQLException, statement)
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T> DatabaseTransaction.withErrorHandler(statement: DatabaseTransaction.() -> T): T {
|
||||
return try {
|
||||
statement()
|
||||
} catch (e: Exception) {
|
||||
if ((e is SQLException || e is PersistenceException || e is HospitalizeFlowException)) {
|
||||
errorHandler(e)
|
||||
}
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes given statement in the scope of transaction with the transaction level specified at the creation time.
|
||||
* @param statement to be executed in the scope of this transaction.
|
||||
* @param recoverableFailureTolerance number of transaction commit retries for SQL while SQL exception is encountered.
|
||||
*/
|
||||
fun <T> transaction(recoverableFailureTolerance: Int, statement: DatabaseTransaction.() -> T): T {
|
||||
return transaction(defaultIsolationLevel, recoverableFailureTolerance, false, statement)
|
||||
return transaction(defaultIsolationLevel, recoverableFailureTolerance, false, false, statement)
|
||||
}
|
||||
|
||||
private fun <T> inTopLevelTransaction(isolationLevel: TransactionIsolationLevel, recoverableFailureTolerance: Int,
|
||||
@ -292,7 +301,13 @@ class CordaPersistence(
|
||||
|
||||
override fun close() {
|
||||
// DataSource doesn't implement AutoCloseable so we just have to hope that the implementation does so that we can close it
|
||||
(_dataSource as? AutoCloseable)?.close()
|
||||
val mayBeAutoClosableDataSource = _dataSource as? AutoCloseable
|
||||
if(mayBeAutoClosableDataSource != null) {
|
||||
log.info("Closing $mayBeAutoClosableDataSource")
|
||||
mayBeAutoClosableDataSource.close()
|
||||
} else {
|
||||
log.warn("$_dataSource has not been properly closed")
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
|
||||
|
@ -39,9 +39,13 @@ class DatabaseTransaction(
|
||||
}
|
||||
|
||||
// Returns a delegate which overrides certain operations that we do not want CorDapp developers to call.
|
||||
val restrictedEntityManager: RestrictedEntityManager by lazy {
|
||||
val entityManager = session as EntityManager
|
||||
RestrictedEntityManager(entityManager)
|
||||
|
||||
val entityManager: EntityManager get() {
|
||||
// Always retrieve new session ([Session] implements [EntityManager])
|
||||
// Note, this does not replace the top level hibernate session
|
||||
val session = database.entityManagerFactory.withOptions().connection(connection).openSession()
|
||||
session.beginTransaction()
|
||||
return session
|
||||
}
|
||||
|
||||
val session: Session by sessionDelegate
|
||||
@ -73,6 +77,10 @@ class DatabaseTransaction(
|
||||
throw DatabaseTransactionException(it)
|
||||
}
|
||||
if (sessionDelegate.isInitialized()) {
|
||||
// The [sessionDelegate] must be initialised otherwise calling [entityManager] will cause an exception
|
||||
if(session.transaction.rollbackOnly) {
|
||||
throw RolledBackDatabaseSessionException()
|
||||
}
|
||||
hibernateTransaction.commit()
|
||||
}
|
||||
connection.commit()
|
||||
@ -124,4 +132,6 @@ class DatabaseTransaction(
|
||||
/**
|
||||
* Wrapper exception, for any exception registered as [DatabaseTransaction.firstExceptionInDatabaseTransaction].
|
||||
*/
|
||||
class DatabaseTransactionException(override val cause: Throwable): CordaRuntimeException(cause.message, cause)
|
||||
class DatabaseTransactionException(override val cause: Throwable): CordaRuntimeException(cause.message, cause)
|
||||
|
||||
class RolledBackDatabaseSessionException : CordaRuntimeException("Attempted to commit database transaction marked for rollback")
|
@ -5,24 +5,15 @@ import net.corda.core.internal.NamedCacheFactory
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.schemas.MappedSchema
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.toHexString
|
||||
import net.corda.nodeapi.internal.persistence.factory.CordaSessionFactoryFactory
|
||||
import org.hibernate.SessionFactory
|
||||
import org.hibernate.boot.Metadata
|
||||
import org.hibernate.boot.MetadataBuilder
|
||||
import org.hibernate.boot.MetadataSources
|
||||
import org.hibernate.boot.registry.BootstrapServiceRegistryBuilder
|
||||
import org.hibernate.boot.registry.classloading.internal.ClassLoaderServiceImpl
|
||||
import org.hibernate.boot.registry.classloading.spi.ClassLoaderService
|
||||
import org.hibernate.cfg.Configuration
|
||||
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider
|
||||
import org.hibernate.service.UnknownUnwrapTypeException
|
||||
import org.hibernate.type.AbstractSingleColumnStandardBasicType
|
||||
import org.hibernate.type.MaterializedBlobType
|
||||
import org.hibernate.type.descriptor.java.PrimitiveByteArrayTypeDescriptor
|
||||
import org.hibernate.type.descriptor.sql.BlobTypeDescriptor
|
||||
import org.hibernate.type.descriptor.sql.VarbinaryTypeDescriptor
|
||||
import java.lang.management.ManagementFactory
|
||||
import java.sql.Connection
|
||||
import java.util.ServiceLoader
|
||||
import javax.management.ObjectName
|
||||
import javax.persistence.AttributeConverter
|
||||
|
||||
@ -30,35 +21,38 @@ class HibernateConfiguration(
|
||||
schemas: Set<MappedSchema>,
|
||||
private val databaseConfig: DatabaseConfig,
|
||||
private val attributeConverters: Collection<AttributeConverter<*, *>>,
|
||||
private val jdbcUrl: String,
|
||||
jdbcUrl: String,
|
||||
cacheFactory: NamedCacheFactory,
|
||||
val customClassLoader: ClassLoader? = null
|
||||
) {
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
|
||||
// register custom converters
|
||||
fun buildHibernateMetadata(metadataBuilder: MetadataBuilder, jdbcUrl:String, attributeConverters: Collection<AttributeConverter<*, *>>): Metadata {
|
||||
metadataBuilder.run {
|
||||
attributeConverters.forEach { applyAttributeConverter(it) }
|
||||
// Register a tweaked version of `org.hibernate.type.MaterializedBlobType` that truncates logged messages.
|
||||
// to avoid OOM when large blobs might get logged.
|
||||
applyBasicType(CordaMaterializedBlobType, CordaMaterializedBlobType.name)
|
||||
applyBasicType(CordaWrapperBinaryType, CordaWrapperBinaryType.name)
|
||||
// Will be used in open core
|
||||
fun buildHibernateMetadata(metadataBuilder: MetadataBuilder, jdbcUrl: String, attributeConverters: Collection<AttributeConverter<*, *>>): Metadata {
|
||||
val sff = findSessionFactoryFactory(jdbcUrl, null)
|
||||
return sff.buildHibernateMetadata(metadataBuilder, attributeConverters)
|
||||
}
|
||||
|
||||
// Create a custom type that will map a blob to byteA in postgres and as a normal blob for all other dbms.
|
||||
// This is required for the Checkpoints as a workaround for the issue that postgres has on azure.
|
||||
if (jdbcUrl.contains(":postgresql:", ignoreCase = true)) {
|
||||
applyBasicType(MapBlobToPostgresByteA, MapBlobToPostgresByteA.name)
|
||||
} else {
|
||||
applyBasicType(MapBlobToNormalBlob, MapBlobToNormalBlob.name)
|
||||
}
|
||||
private fun findSessionFactoryFactory(jdbcUrl: String, customClassLoader: ClassLoader?): CordaSessionFactoryFactory {
|
||||
val serviceLoader = if (customClassLoader != null)
|
||||
ServiceLoader.load(CordaSessionFactoryFactory::class.java, customClassLoader)
|
||||
else
|
||||
ServiceLoader.load(CordaSessionFactoryFactory::class.java)
|
||||
|
||||
return build()
|
||||
val sessionFactories = serviceLoader.filter { it.canHandleDatabase(jdbcUrl) }
|
||||
when (sessionFactories.size) {
|
||||
0 -> throw HibernateConfigException("Failed to find a SessionFactoryFactory to handle $jdbcUrl " +
|
||||
"- factories present for ${serviceLoader.map { it.databaseType }}")
|
||||
1 -> return sessionFactories.single()
|
||||
else -> throw HibernateConfigException("Found several SessionFactoryFactory classes to handle $jdbcUrl " +
|
||||
"- classes ${sessionFactories.map { it.javaClass.canonicalName }}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val sessionFactoryFactory = findSessionFactoryFactory(jdbcUrl, customClassLoader)
|
||||
|
||||
private val sessionFactories = cacheFactory.buildNamed<Set<MappedSchema>, SessionFactory>(Caffeine.newBuilder(), "HibernateConfiguration_sessionFactories")
|
||||
|
||||
val sessionFactoryForRegisteredSchemas = schemas.let {
|
||||
@ -70,35 +64,7 @@ class HibernateConfiguration(
|
||||
fun sessionFactoryForSchemas(key: Set<MappedSchema>): SessionFactory = sessionFactories.get(key, ::makeSessionFactoryForSchemas)!!
|
||||
|
||||
private fun makeSessionFactoryForSchemas(schemas: Set<MappedSchema>): SessionFactory {
|
||||
logger.info("Creating session factory for schemas: $schemas")
|
||||
val serviceRegistry = BootstrapServiceRegistryBuilder().build()
|
||||
val metadataSources = MetadataSources(serviceRegistry)
|
||||
|
||||
val hbm2dll: String =
|
||||
if(databaseConfig.initialiseSchema && databaseConfig.initialiseAppSchema == SchemaInitializationType.UPDATE) {
|
||||
"update"
|
||||
} else if((!databaseConfig.initialiseSchema && databaseConfig.initialiseAppSchema == SchemaInitializationType.UPDATE)
|
||||
|| databaseConfig.initialiseAppSchema == SchemaInitializationType.VALIDATE) {
|
||||
"validate"
|
||||
} else {
|
||||
"none"
|
||||
}
|
||||
|
||||
// We set a connection provider as the auto schema generation requires it. The auto schema generation will not
|
||||
// necessarily remain and would likely be replaced by something like Liquibase. For now it is very convenient though.
|
||||
val config = Configuration(metadataSources).setProperty("hibernate.connection.provider_class", NodeDatabaseConnectionProvider::class.java.name)
|
||||
.setProperty("hibernate.format_sql", "true")
|
||||
.setProperty("hibernate.hbm2ddl.auto", hbm2dll)
|
||||
.setProperty("javax.persistence.validation.mode", "none")
|
||||
.setProperty("hibernate.connection.isolation", databaseConfig.transactionIsolationLevel.jdbcValue.toString())
|
||||
|
||||
schemas.forEach { schema ->
|
||||
// TODO: require mechanism to set schemaOptions (databaseSchema, tablePrefix) which are not global to session
|
||||
schema.mappedTypes.forEach { config.addAnnotatedClass(it) }
|
||||
}
|
||||
|
||||
val sessionFactory = buildSessionFactory(config, metadataSources, customClassLoader)
|
||||
logger.info("Created session factory for schemas: $schemas")
|
||||
val sessionFactory = sessionFactoryFactory.makeSessionFactoryForSchemas(databaseConfig, schemas, customClassLoader, attributeConverters)
|
||||
|
||||
// export Hibernate JMX statistics
|
||||
if (databaseConfig.exportHibernateJMXStatistics)
|
||||
@ -123,27 +89,6 @@ class HibernateConfiguration(
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildSessionFactory(config: Configuration, metadataSources: MetadataSources, customClassLoader: ClassLoader?): SessionFactory {
|
||||
config.standardServiceRegistryBuilder.applySettings(config.properties)
|
||||
|
||||
if (customClassLoader != null) {
|
||||
config.standardServiceRegistryBuilder.addService(
|
||||
ClassLoaderService::class.java,
|
||||
ClassLoaderServiceImpl(customClassLoader))
|
||||
}
|
||||
|
||||
@Suppress("DEPRECATION")
|
||||
val metadataBuilder = metadataSources.getMetadataBuilder(config.standardServiceRegistryBuilder.build())
|
||||
val metadata = buildHibernateMetadata(metadataBuilder, jdbcUrl, attributeConverters)
|
||||
return metadata.sessionFactoryBuilder.run {
|
||||
allowOutOfTransactionUpdateOperations(true)
|
||||
applySecondLevelCacheSupport(false)
|
||||
applyQueryCacheSupport(false)
|
||||
enableReleaseResourcesOnCloseEnabled(true)
|
||||
build()
|
||||
}
|
||||
}
|
||||
|
||||
// Supply Hibernate with connections from our underlying Exposed database integration. Only used
|
||||
// during schema creation / update.
|
||||
class NodeDatabaseConnectionProvider : ConnectionProvider {
|
||||
@ -168,55 +113,5 @@ class HibernateConfiguration(
|
||||
override fun isUnwrappableAs(unwrapType: Class<*>?): Boolean = unwrapType == NodeDatabaseConnectionProvider::class.java
|
||||
}
|
||||
|
||||
// A tweaked version of `org.hibernate.type.MaterializedBlobType` that truncates logged messages. Also logs in hex.
|
||||
object CordaMaterializedBlobType : AbstractSingleColumnStandardBasicType<ByteArray>(BlobTypeDescriptor.DEFAULT, CordaPrimitiveByteArrayTypeDescriptor) {
|
||||
override fun getName(): String {
|
||||
return "materialized_blob"
|
||||
}
|
||||
}
|
||||
|
||||
// A tweaked version of `org.hibernate.type.descriptor.java.PrimitiveByteArrayTypeDescriptor` that truncates logged messages.
|
||||
private object CordaPrimitiveByteArrayTypeDescriptor : PrimitiveByteArrayTypeDescriptor() {
|
||||
private const val LOG_SIZE_LIMIT = 1024
|
||||
|
||||
override fun extractLoggableRepresentation(value: ByteArray?): String {
|
||||
return if (value == null) {
|
||||
super.extractLoggableRepresentation(value)
|
||||
} else {
|
||||
if (value.size <= LOG_SIZE_LIMIT) {
|
||||
"[size=${value.size}, value=${value.toHexString()}]"
|
||||
} else {
|
||||
"[size=${value.size}, value=${value.copyOfRange(0, LOG_SIZE_LIMIT).toHexString()}...truncated...]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A tweaked version of `org.hibernate.type.WrapperBinaryType` that deals with ByteArray (java primitive byte[] type).
|
||||
object CordaWrapperBinaryType : AbstractSingleColumnStandardBasicType<ByteArray>(VarbinaryTypeDescriptor.INSTANCE, PrimitiveByteArrayTypeDescriptor.INSTANCE) {
|
||||
override fun getRegistrationKeys(): Array<String> {
|
||||
return arrayOf(name, "ByteArray", ByteArray::class.java.name)
|
||||
}
|
||||
|
||||
override fun getName(): String {
|
||||
return "corda-wrapper-binary"
|
||||
}
|
||||
}
|
||||
|
||||
// Maps to a byte array on postgres.
|
||||
object MapBlobToPostgresByteA : AbstractSingleColumnStandardBasicType<ByteArray>(VarbinaryTypeDescriptor.INSTANCE, PrimitiveByteArrayTypeDescriptor.INSTANCE) {
|
||||
override fun getRegistrationKeys(): Array<String> {
|
||||
return arrayOf(name, "ByteArray", ByteArray::class.java.name)
|
||||
}
|
||||
|
||||
override fun getName(): String {
|
||||
return "corda-blob"
|
||||
}
|
||||
}
|
||||
|
||||
object MapBlobToNormalBlob : MaterializedBlobType() {
|
||||
override fun getName(): String {
|
||||
return "corda-blob"
|
||||
}
|
||||
}
|
||||
fun getExtraConfiguration(key: String) = sessionFactoryFactory.getExtraConfiguration(key)
|
||||
}
|
||||
|
@ -0,0 +1,80 @@
|
||||
package net.corda.nodeapi.internal.persistence
|
||||
|
||||
import java.sql.Connection
|
||||
import java.sql.Savepoint
|
||||
import java.util.concurrent.Executor
|
||||
|
||||
/**
|
||||
* A delegate of [Connection] which disallows some operations.
|
||||
*/
|
||||
@Suppress("TooManyFunctions")
|
||||
class RestrictedConnection(private val delegate : Connection) : Connection by delegate {
|
||||
|
||||
override fun abort(executor: Executor?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun clearWarnings() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun commit() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setSavepoint(): Savepoint? {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setSavepoint(name : String?): Savepoint? {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun releaseSavepoint(savepoint: Savepoint?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun rollback() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun rollback(savepoint: Savepoint?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setCatalog(catalog : String?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setTransactionIsolation(level: Int) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setTypeMap(map: MutableMap<String, Class<*>>?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setHoldability(holdability: Int) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setSchema(schema: String?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setNetworkTimeout(executor: Executor?, milliseconds: Int) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setAutoCommit(autoCommit: Boolean) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
|
||||
override fun setReadOnly(readOnly: Boolean) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.jdbcSession.")
|
||||
}
|
||||
}
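As a rough illustration of the intent (not taken from the commit), CorDapp code that obtains the connection through ServiceHub.jdbcSession() can still run queries, but transaction-control and lifecycle calls fail fast. The table name in the query is illustrative.

val connection = serviceHub.jdbcSession()           // expected to be wrapped in RestrictedConnection
connection.prepareStatement("SELECT tx_id FROM node_transactions").use { statement ->
    statement.executeQuery()                         // plain reads and writes still delegate to the real connection
}
// connection.commit()                               // throws UnsupportedOperationException
// connection.close()                                // throws UnsupportedOperationException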
|
@ -1,19 +1,63 @@
|
||||
package net.corda.nodeapi.internal.persistence
|
||||
|
||||
import javax.persistence.EntityManager
|
||||
import javax.persistence.EntityTransaction
|
||||
import javax.persistence.LockModeType
|
||||
import javax.persistence.metamodel.Metamodel
|
||||
|
||||
/**
|
||||
* A delegate of [EntityManager] which disallows some operations.
|
||||
*/
|
||||
class RestrictedEntityManager(private val delegate: EntityManager) : EntityManager by delegate {
|
||||
|
||||
override fun getTransaction(): EntityTransaction {
|
||||
return RestrictedEntityTransaction(delegate.transaction)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun clear() {
|
||||
override fun <T : Any?> unwrap(cls: Class<T>?): T {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
// TODO: Figure out which other methods on EntityManager need to be blocked.
|
||||
override fun getDelegate(): Any {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun getMetamodel(): Metamodel? {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun joinTransaction() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun lock(entity: Any?, lockMode: LockModeType?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun lock(entity: Any?, lockMode: LockModeType?, properties: MutableMap<String, Any>?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun setProperty(propertyName: String?, value: Any?) {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
}
|
||||
|
||||
class RestrictedEntityTransaction(private val delegate: EntityTransaction) : EntityTransaction by delegate {
|
||||
|
||||
override fun rollback() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun commit() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
|
||||
override fun begin() {
|
||||
throw UnsupportedOperationException("This method cannot be called via ServiceHub.withEntityManager.")
|
||||
}
|
||||
}
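A matching sketch for the entity manager side (again illustrative, with a hypothetical entity): ordinary persistence calls are delegated, while lifecycle and transaction-control calls throw.

serviceHub.withEntityManager {
    persist(MyEntity(id = 3))                        // delegated to the underlying EntityManager
    flush()
    // close()                                       // throws UnsupportedOperationException
    // transaction.commit()                          // throws UnsupportedOperationException (RestrictedEntityTransaction)
}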
|
@ -0,0 +1,149 @@
|
||||
package net.corda.nodeapi.internal.persistence.factory
|
||||
|
||||
import net.corda.core.schemas.MappedSchema
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.toHexString
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import net.corda.nodeapi.internal.persistence.HibernateConfiguration
|
||||
import net.corda.nodeapi.internal.persistence.SchemaInitializationType
|
||||
import org.hibernate.SessionFactory
|
||||
import org.hibernate.boot.Metadata
|
||||
import org.hibernate.boot.MetadataBuilder
|
||||
import org.hibernate.boot.MetadataSources
|
||||
import org.hibernate.boot.registry.BootstrapServiceRegistryBuilder
|
||||
import org.hibernate.boot.registry.classloading.internal.ClassLoaderServiceImpl
|
||||
import org.hibernate.boot.registry.classloading.spi.ClassLoaderService
|
||||
import org.hibernate.cfg.Configuration
|
||||
import org.hibernate.type.AbstractSingleColumnStandardBasicType
|
||||
import org.hibernate.type.MaterializedBlobType
|
||||
import org.hibernate.type.descriptor.java.PrimitiveByteArrayTypeDescriptor
|
||||
import org.hibernate.type.descriptor.sql.BlobTypeDescriptor
|
||||
import org.hibernate.type.descriptor.sql.VarbinaryTypeDescriptor
|
||||
import javax.persistence.AttributeConverter
|
||||
|
||||
abstract class BaseSessionFactoryFactory : CordaSessionFactoryFactory {
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
}
|
||||
|
||||
open fun buildHibernateConfig(databaseConfig: DatabaseConfig, metadataSources: MetadataSources): Configuration {
|
||||
val hbm2dll: String =
|
||||
if (databaseConfig.initialiseSchema && databaseConfig.initialiseAppSchema == SchemaInitializationType.UPDATE) {
|
||||
"update"
|
||||
} else if ((!databaseConfig.initialiseSchema && databaseConfig.initialiseAppSchema == SchemaInitializationType.UPDATE)
|
||||
|| databaseConfig.initialiseAppSchema == SchemaInitializationType.VALIDATE) {
|
||||
"validate"
|
||||
} else {
|
||||
"none"
|
||||
}
|
||||
// We set a connection provider as the auto schema generation requires it. The auto schema generation will not
|
||||
// necessarily remain and would likely be replaced by something like Liquibase. For now it is very convenient though.
|
||||
return Configuration(metadataSources).setProperty("hibernate.connection.provider_class", HibernateConfiguration.NodeDatabaseConnectionProvider::class.java.name)
|
||||
.setProperty("hibernate.format_sql", "true")
|
||||
.setProperty("javax.persistence.validation.mode", "none")
|
||||
.setProperty("hibernate.connection.isolation", databaseConfig.transactionIsolationLevel.jdbcValue.toString())
|
||||
.setProperty("hibernate.hbm2ddl.auto", hbm2dll)
|
||||
.setProperty("hibernate.jdbc.time_zone", "UTC")
|
||||
}
|
||||
|
||||
override fun buildHibernateMetadata(metadataBuilder: MetadataBuilder, attributeConverters: Collection<AttributeConverter<*, *>>): Metadata {
|
||||
return metadataBuilder.run {
|
||||
attributeConverters.forEach { applyAttributeConverter(it) }
|
||||
// Register a tweaked version of `org.hibernate.type.MaterializedBlobType` that truncates logged messages.
|
||||
// This is to avoid OOM when large blobs might get logged.
|
||||
applyBasicType(CordaMaterializedBlobType, CordaMaterializedBlobType.name)
|
||||
applyBasicType(CordaWrapperBinaryType, CordaWrapperBinaryType.name)
|
||||
applyBasicType(MapBlobToNormalBlob, MapBlobToNormalBlob.name)
|
||||
|
||||
build()
|
||||
}
|
||||
}
|
||||
|
||||
fun buildSessionFactory(
|
||||
config: Configuration,
|
||||
metadataSources: MetadataSources,
|
||||
customClassLoader: ClassLoader?,
|
||||
attributeConverters: Collection<AttributeConverter<*, *>>): SessionFactory {
|
||||
config.standardServiceRegistryBuilder.applySettings(config.properties)
|
||||
|
||||
if (customClassLoader != null) {
|
||||
config.standardServiceRegistryBuilder.addService(
|
||||
ClassLoaderService::class.java,
|
||||
ClassLoaderServiceImpl(customClassLoader))
|
||||
}
|
||||
|
||||
@Suppress("DEPRECATION")
|
||||
val metadataBuilder = metadataSources.getMetadataBuilder(config.standardServiceRegistryBuilder.build())
|
||||
val metadata = buildHibernateMetadata(metadataBuilder, attributeConverters)
|
||||
return metadata.sessionFactoryBuilder.run {
|
||||
allowOutOfTransactionUpdateOperations(true)
|
||||
applySecondLevelCacheSupport(false)
|
||||
applyQueryCacheSupport(false)
|
||||
enableReleaseResourcesOnCloseEnabled(true)
|
||||
build()
|
||||
}
|
||||
}
|
||||
|
||||
final override fun makeSessionFactoryForSchemas(
|
||||
databaseConfig: DatabaseConfig,
|
||||
schemas: Set<MappedSchema>,
|
||||
customClassLoader: ClassLoader?,
|
||||
attributeConverters: Collection<AttributeConverter<*, *>>): SessionFactory {
|
||||
logger.info("Creating session factory for schemas: $schemas")
|
||||
val serviceRegistry = BootstrapServiceRegistryBuilder().build()
|
||||
val metadataSources = MetadataSources(serviceRegistry)
|
||||
|
||||
val config = buildHibernateConfig(databaseConfig, metadataSources)
|
||||
schemas.forEach { schema ->
|
||||
schema.mappedTypes.forEach { config.addAnnotatedClass(it) }
|
||||
}
|
||||
val sessionFactory = buildSessionFactory(config, metadataSources, customClassLoader, attributeConverters)
|
||||
logger.info("Created session factory for schemas: $schemas")
|
||||
return sessionFactory
|
||||
}
|
||||
|
||||
override fun getExtraConfiguration(key: String): Any? {
|
||||
return null
|
||||
}
|
||||
|
||||
// A tweaked version of `org.hibernate.type.WrapperBinaryType` that deals with ByteArray (java primitive byte[] type).
|
||||
object CordaWrapperBinaryType : AbstractSingleColumnStandardBasicType<ByteArray>(VarbinaryTypeDescriptor.INSTANCE, PrimitiveByteArrayTypeDescriptor.INSTANCE) {
|
||||
override fun getRegistrationKeys(): Array<String> {
|
||||
return arrayOf(name, "ByteArray", ByteArray::class.java.name)
|
||||
}
|
||||
|
||||
override fun getName(): String {
|
||||
return "corda-wrapper-binary"
|
||||
}
|
||||
}
|
||||
|
||||
object MapBlobToNormalBlob : MaterializedBlobType() {
|
||||
override fun getName(): String {
|
||||
return "corda-blob"
|
||||
}
|
||||
}
|
||||
|
||||
// A tweaked version of `org.hibernate.type.descriptor.java.PrimitiveByteArrayTypeDescriptor` that truncates logged messages.
|
||||
private object CordaPrimitiveByteArrayTypeDescriptor : PrimitiveByteArrayTypeDescriptor() {
|
||||
private const val LOG_SIZE_LIMIT = 1024
|
||||
|
||||
override fun extractLoggableRepresentation(value: ByteArray?): String {
|
||||
return if (value == null) {
|
||||
super.extractLoggableRepresentation(value)
|
||||
} else {
|
||||
if (value.size <= LOG_SIZE_LIMIT) {
|
||||
"[size=${value.size}, value=${value.toHexString()}]"
|
||||
} else {
|
||||
"[size=${value.size}, value=${value.copyOfRange(0, LOG_SIZE_LIMIT).toHexString()}...truncated...]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A tweaked version of `org.hibernate.type.MaterializedBlobType` that truncates logged messages. Also logs in hex.
|
||||
object CordaMaterializedBlobType : AbstractSingleColumnStandardBasicType<ByteArray>(BlobTypeDescriptor.DEFAULT, CordaPrimitiveByteArrayTypeDescriptor) {
|
||||
override fun getName(): String {
|
||||
return "materialized_blob"
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
package net.corda.nodeapi.internal.persistence.factory
|
||||
|
||||
import net.corda.core.schemas.MappedSchema
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import org.hibernate.SessionFactory
|
||||
import org.hibernate.boot.Metadata
|
||||
import org.hibernate.boot.MetadataBuilder
|
||||
import javax.persistence.AttributeConverter
|
||||
|
||||
interface CordaSessionFactoryFactory {
|
||||
val databaseType: String
|
||||
fun canHandleDatabase(jdbcUrl: String): Boolean
|
||||
fun makeSessionFactoryForSchemas(
|
||||
databaseConfig: DatabaseConfig,
|
||||
schemas: Set<MappedSchema>,
|
||||
customClassLoader: ClassLoader?,
|
||||
attributeConverters: Collection<AttributeConverter<*, *>>): SessionFactory
|
||||
fun getExtraConfiguration(key: String): Any?
|
||||
fun buildHibernateMetadata(metadataBuilder: MetadataBuilder, attributeConverters: Collection<AttributeConverter<*, *>>): Metadata
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
package net.corda.nodeapi.internal.persistence.factory
|
||||
|
||||
class H2SessionFactoryFactory : BaseSessionFactoryFactory() {
|
||||
override fun canHandleDatabase(jdbcUrl: String): Boolean = jdbcUrl.startsWith("jdbc:h2:")
|
||||
override val databaseType: String = "H2"
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
package net.corda.nodeapi.internal.persistence.factory
|
||||
|
||||
import org.hibernate.boot.Metadata
|
||||
import org.hibernate.boot.MetadataBuilder
|
||||
import org.hibernate.type.AbstractSingleColumnStandardBasicType
|
||||
import org.hibernate.type.descriptor.java.PrimitiveByteArrayTypeDescriptor
|
||||
import org.hibernate.type.descriptor.sql.VarbinaryTypeDescriptor
|
||||
import javax.persistence.AttributeConverter
|
||||
|
||||
class PostgresSessionFactoryFactory : BaseSessionFactoryFactory() {
|
||||
override fun buildHibernateMetadata(metadataBuilder: MetadataBuilder, attributeConverters: Collection<AttributeConverter<*, *>>): Metadata {
|
||||
return metadataBuilder.run {
|
||||
attributeConverters.forEach { applyAttributeConverter(it) }
|
||||
// Register a tweaked version of `org.hibernate.type.MaterializedBlobType` that truncates logged messages.
|
||||
// This is to avoid OOM when large blobs might get logged.
|
||||
applyBasicType(CordaMaterializedBlobType, CordaMaterializedBlobType.name)
|
||||
applyBasicType(CordaWrapperBinaryType, CordaWrapperBinaryType.name)
|
||||
|
||||
// Create a custom type that will map a blob to byteA in postgres
|
||||
// This is required for the Checkpoints as a workaround for the issue that postgres has on azure.
|
||||
applyBasicType(MapBlobToPostgresByteA, MapBlobToPostgresByteA.name)
|
||||
|
||||
build()
|
||||
}
|
||||
}
|
||||
|
||||
override fun canHandleDatabase(jdbcUrl: String): Boolean = jdbcUrl.contains(":postgresql:")
|
||||
|
||||
// Maps to a byte array on postgres.
|
||||
object MapBlobToPostgresByteA : AbstractSingleColumnStandardBasicType<ByteArray>(VarbinaryTypeDescriptor.INSTANCE, PrimitiveByteArrayTypeDescriptor.INSTANCE) {
|
||||
override fun getRegistrationKeys(): Array<String> {
|
||||
return arrayOf(name, "ByteArray", ByteArray::class.java.name)
|
||||
}
|
||||
|
||||
override fun getName(): String {
|
||||
return "corda-blob"
|
||||
}
|
||||
}
|
||||
|
||||
override val databaseType: String = "PostgreSQL"
|
||||
}
|
@ -45,7 +45,11 @@ internal class ConnectionStateMachine(private val serverMode: Boolean,
|
||||
userName: String?,
|
||||
password: String?) : BaseHandler() {
|
||||
companion object {
|
||||
private const val IDLE_TIMEOUT = 10000
|
||||
private const val CORDA_AMQP_FRAME_SIZE_PROP_NAME = "net.corda.nodeapi.connectionstatemachine.AmqpMaxFrameSize"
|
||||
private const val CORDA_AMQP_IDLE_TIMEOUT_PROP_NAME = "net.corda.nodeapi.connectionstatemachine.AmqpIdleTimeout"
|
||||
|
||||
private val MAX_FRAME_SIZE = Integer.getInteger(CORDA_AMQP_FRAME_SIZE_PROP_NAME, 128 * 1024)
|
||||
private val IDLE_TIMEOUT = Integer.getInteger(CORDA_AMQP_IDLE_TIMEOUT_PROP_NAME, 10 * 1000)
|
||||
private val log = contextLogger()
|
||||
}
|
||||
|
||||
@ -102,6 +106,7 @@ internal class ConnectionStateMachine(private val serverMode: Boolean,
|
||||
transport.context = connection
|
||||
@Suppress("UsePropertyAccessSyntax")
|
||||
transport.setEmitFlowEventOnSend(true)
|
||||
transport.maxFrameSize = MAX_FRAME_SIZE
|
||||
connection.collect(collector)
|
||||
val sasl = transport.sasl()
|
||||
if (userName != null) {
|
||||
@ -224,12 +229,17 @@ internal class ConnectionStateMachine(private val serverMode: Boolean,
|
||||
}
|
||||
|
||||
override fun onTransportClosed(event: Event) {
|
||||
val transport = event.transport
|
||||
logDebugWithMDC { "Transport Closed ${transport.prettyPrint}" }
|
||||
if (transport == this.transport) {
|
||||
doTransportClose(event.transport) { "Transport Closed ${transport.prettyPrint}" }
|
||||
}
|
||||
|
||||
private fun doTransportClose(transport: Transport?, msg: () -> String) {
|
||||
if (transport != null && transport == this.transport && transport.context != null) {
|
||||
logDebugWithMDC(msg)
|
||||
transport.unbind()
|
||||
transport.free()
|
||||
transport.context = null
|
||||
} else {
|
||||
logDebugWithMDC { "Nothing to do in case of: ${msg()}" }
|
||||
}
|
||||
}
|
||||
|
||||
@ -259,6 +269,9 @@ internal class ConnectionStateMachine(private val serverMode: Boolean,
|
||||
val channel = connection?.context as? Channel
|
||||
channel?.writeAndFlush(transport)
|
||||
}
|
||||
} else {
|
||||
logDebugWithMDC { "Transport is already closed: ${transport.prettyPrint}" }
|
||||
doTransportClose(transport) { "Freeing-up resources associated with transport" }
|
||||
}
|
||||
}
|
||||
|
||||
@ -309,13 +322,7 @@ internal class ConnectionStateMachine(private val serverMode: Boolean,
|
||||
// If TRANSPORT_CLOSED event was already processed, the 'transport' in all subsequent events is set to null.
|
||||
// There is, however, a chance of missing TRANSPORT_CLOSED event, e.g. when disconnect occurs before opening remote session.
|
||||
// In such cases we must explicitly cleanup the 'transport' in order to guarantee the delivery of CONNECTION_FINAL event.
|
||||
val transport = event.transport
|
||||
if (transport == this.transport) {
|
||||
logDebugWithMDC { "Missed TRANSPORT_CLOSED: force cleanup ${transport.prettyPrint}" }
|
||||
transport.unbind()
|
||||
transport.free()
|
||||
transport.context = null
|
||||
}
|
||||
doTransportClose(event.transport) { "Missed TRANSPORT_CLOSED in onSessionFinal: force cleanup ${transport.prettyPrint}" }
|
||||
}
|
||||
}
|
||||
|
||||
@ -488,7 +495,9 @@ internal class ConnectionStateMachine(private val serverMode: Boolean,
|
||||
}
|
||||
|
||||
fun transportWriteMessage(msg: SendableMessageImpl) {
|
||||
msg.buf = encodePayloadBytes(msg)
|
||||
val encoded = encodePayloadBytes(msg)
|
||||
msg.release()
|
||||
msg.buf = encoded
|
||||
val messageQueue = messageQueues.getOrPut(msg.topic, { LinkedList() })
|
||||
messageQueue.offer(msg)
|
||||
if (session != null) {
|
||||
|
@ -38,7 +38,9 @@ internal class EventProcessor(private val channel: Channel,
userName: String?,
password: String?) {
companion object {
private const val FLOW_WINDOW_SIZE = 10
private const val CORDA_AMQP_FLOW_WINDOW_SIZE_PROP_NAME = "net.corda.nodeapi.eventprocessor.FlowWindowSize"

private val FLOW_WINDOW_SIZE = Integer.getInteger(CORDA_AMQP_FLOW_WINDOW_SIZE_PROP_NAME, 5)
private val log = contextLogger()
}

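The companion-object changes above make the AMQP frame size, idle timeout and flow window tunable via system properties rather than hard-coded constants. A hedged sketch of overriding them before the messaging layer starts; the property names are taken verbatim from the diff, the values are purely illustrative:

// Illustrative only: override the AMQP tuning knobs introduced above at process startup.
fun applyAmqpTuning() {
    System.setProperty("net.corda.nodeapi.connectionstatemachine.AmqpMaxFrameSize", (256 * 1024).toString())
    System.setProperty("net.corda.nodeapi.connectionstatemachine.AmqpIdleTimeout", (30 * 1000).toString())
    System.setProperty("net.corda.nodeapi.eventprocessor.FlowWindowSize", "10")
}

The same could be done with -D flags on the JVM command line; Integer.getInteger() falls back to the shown defaults when the property is absent.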
@ -11,4 +11,5 @@ interface ApplicationMessage {
val destinationLegalName: String
val destinationLink: NetworkHostAndPort
val applicationProperties: Map<String, Any?>
fun release()
}
@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.protonwrapper.messages.impl

import io.netty.channel.Channel
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import org.apache.qpid.proton.engine.Delivery
@ -10,7 +11,7 @@ import org.apache.qpid.proton.engine.Delivery
* An internal packet management class that allows tracking of asynchronous acknowledgements
* that in turn send Delivery messages back to the originator.
*/
internal class ReceivedMessageImpl(override val payload: ByteArray,
internal class ReceivedMessageImpl(override var payload: ByteArray,
override val topic: String,
override val sourceLegalName: String,
override val sourceLink: NetworkHostAndPort,
@ -19,11 +20,25 @@ internal class ReceivedMessageImpl(override val payload: ByteArray,
override val applicationProperties: Map<String, Any?>,
private val channel: Channel,
private val delivery: Delivery) : ReceivedMessage {
companion object {
private val emptyPayload = ByteArray(0)
private val logger = contextLogger()
}

data class MessageCompleter(val status: MessageStatus, val delivery: Delivery)

override fun release() {
payload = emptyPayload
}

override fun complete(accepted: Boolean) {
release()
val status = if (accepted) MessageStatus.Acknowledged else MessageStatus.Rejected
channel.writeAndFlush(MessageCompleter(status, delivery))
if (channel.isActive) {
channel.writeAndFlush(MessageCompleter(status, delivery))
} else {
logger.info("Not writing $status as $channel is not active")
}
}

override fun toString(): String = "Received ${String(payload)} $topic"
@ -11,11 +11,15 @@ import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
* An internal packet management class that allows handling of the encoded buffers and
* allows registration of an acknowledgement handler when the remote receiver confirms durable storage.
*/
internal class SendableMessageImpl(override val payload: ByteArray,
internal class SendableMessageImpl(override var payload: ByteArray,
override val topic: String,
override val destinationLegalName: String,
override val destinationLink: NetworkHostAndPort,
override val applicationProperties: Map<String, Any?>) : SendableMessage {
companion object {
private val emptyPayload = ByteArray(0)
}

var buf: ByteBuf? = null
@Volatile
var status: MessageStatus = MessageStatus.Unsent
@ -23,12 +27,14 @@ internal class SendableMessageImpl(override val payload: ByteArray,
private val _onComplete = openFuture<MessageStatus>()
override val onComplete: CordaFuture<MessageStatus> get() = _onComplete

fun release() {
override fun release() {
payload = emptyPayload
buf?.release()
buf = null
}

fun doComplete(status: MessageStatus) {
release()
this.status = status
_onComplete.set(status)
}
@ -5,11 +5,16 @@ import io.netty.channel.ChannelDuplexHandler
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelPromise
import io.netty.channel.socket.SocketChannel
import io.netty.handler.proxy.ProxyConnectException
import io.netty.handler.proxy.ProxyConnectionEvent
import io.netty.handler.ssl.SniCompletionEvent
import io.netty.handler.ssl.SslHandler
import io.netty.handler.ssl.SslHandshakeCompletionEvent
import io.netty.util.ReferenceCountUtil
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.nodeapi.internal.ArtemisConstants.MESSAGE_ID_KEY
import net.corda.nodeapi.internal.crypto.x509
import net.corda.nodeapi.internal.protonwrapper.engine.EventProcessor
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
@ -23,6 +28,8 @@ import org.slf4j.MDC
import java.net.InetSocketAddress
import java.nio.channels.ClosedChannelException
import java.security.cert.X509Certificate
import javax.net.ssl.ExtendedSSLSession
import javax.net.ssl.SNIHostName
import javax.net.ssl.SSLException

/**
@ -30,29 +37,35 @@ import javax.net.ssl.SSLException
* It also adds some extra checks to the SSL handshake to support our non-standard certificate checks of legal identity.
* When a valid SSL connection is made, it initialises a proton-j engine instance to handle the protocol layer.
*/
@Suppress("TooManyFunctions")
internal class AMQPChannelHandler(private val serverMode: Boolean,
private val allowedRemoteLegalNames: Set<CordaX500Name>?,
private val keyManagerFactoriesMap: Map<String, CertHoldingKeyManagerFactoryWrapper>,
private val userName: String?,
private val password: String?,
private val trace: Boolean,
private val onOpen: (Pair<SocketChannel, ConnectionChange>) -> Unit,
private val onClose: (Pair<SocketChannel, ConnectionChange>) -> Unit,
private val suppressLogs: Boolean,
private val onOpen: (SocketChannel, ConnectionChange) -> Unit,
private val onClose: (SocketChannel, ConnectionChange) -> Unit,
private val onReceive: (ReceivedMessage) -> Unit) : ChannelDuplexHandler() {
companion object {
private val log = contextLogger()
const val PROXY_LOGGER_NAME = "preProxyLogger"
}

private lateinit var remoteAddress: InetSocketAddress
private var localCert: X509Certificate? = null
private var remoteCert: X509Certificate? = null
private var eventProcessor: EventProcessor? = null
private var suppressClose: Boolean = false
private var badCert: Boolean = false
private var localCert: X509Certificate? = null
private var requestedServerName: String? = null

private fun withMDC(block: () -> Unit) {
val oldMDC = MDC.getCopyOfContextMap() ?: emptyMap<String, String>()
try {
MDC.put("serverMode", serverMode.toString())
MDC.put("remoteAddress", remoteAddress.toString())
MDC.put("remoteAddress", if (::remoteAddress.isInitialized) remoteAddress.toString() else null)
MDC.put("localCert", localCert?.subjectDN?.toString())
MDC.put("remoteCert", remoteCert?.subjectDN?.toString())
MDC.put("allowedRemoteLegalNames", allowedRemoteLegalNames?.joinToString(separator = ";") { it.toString() })
@ -62,39 +75,50 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
}
}

private fun logDebugWithMDC(msg: () -> String) {
if (log.isDebugEnabled) {
withMDC { log.debug(msg()) }
private fun logDebugWithMDC(msgFn: () -> String) {
if (!suppressLogs) {
if (log.isDebugEnabled) {
withMDC { log.debug(msgFn()) }
}
} else {
withMDC { log.trace(msgFn) }
}
}

private fun logInfoWithMDC(msg: String) = withMDC { log.info(msg) }
private fun logInfoWithMDC(msgFn: () -> String) {
if (!suppressLogs) {
if (log.isInfoEnabled) {
withMDC { log.info(msgFn()) }
}
} else {
withMDC { log.trace(msgFn) }
}
}

private fun logWarnWithMDC(msg: String) = withMDC { log.warn(msg) }

private fun logErrorWithMDC(msg: String, ex: Throwable? = null) = withMDC { log.error(msg, ex) }
private fun logWarnWithMDC(msg: String) = withMDC { if (!suppressLogs) log.warn(msg) else log.trace { msg } }

private fun logErrorWithMDC(msg: String, ex: Throwable? = null) = withMDC { if (!suppressLogs) log.error(msg, ex) else log.trace(msg, ex) }

override fun channelActive(ctx: ChannelHandlerContext) {
val ch = ctx.channel()
remoteAddress = ch.remoteAddress() as InetSocketAddress
val localAddress = ch.localAddress() as InetSocketAddress
logInfoWithMDC("New client connection ${ch.id()} from $remoteAddress to $localAddress")
logInfoWithMDC { "New client connection ${ch.id()} from $remoteAddress to $localAddress" }
}

private fun createAMQPEngine(ctx: ChannelHandlerContext) {
val ch = ctx.channel()
eventProcessor = EventProcessor(ch, serverMode, localCert!!.subjectX500Principal.toString(), remoteCert!!.subjectX500Principal.toString(), userName, password)
val connection = eventProcessor!!.connection
val transport = connection.transport as ProtonJTransport
if (trace) {
val connection = eventProcessor!!.connection
val transport = connection.transport as ProtonJTransport
transport.protocolTracer = object : ProtocolTracer {
override fun sentFrame(transportFrame: TransportFrame) {
logInfoWithMDC("${transportFrame.body}")
logInfoWithMDC { "${transportFrame.body}" }
}

override fun receivedFrame(transportFrame: TransportFrame) {
logInfoWithMDC("${transportFrame.body}")
logInfoWithMDC { "${transportFrame.body}" }
}
}
}
@ -104,51 +128,60 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,

override fun channelInactive(ctx: ChannelHandlerContext) {
val ch = ctx.channel()
logInfoWithMDC("Closed client connection ${ch.id()} from $remoteAddress to ${ch.localAddress()}")
onClose(Pair(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false, badCert)))
logInfoWithMDC { "Closed client connection ${ch.id()} from $remoteAddress to ${ch.localAddress()}" }
if (!suppressClose) {
onClose(ch as SocketChannel, ConnectionChange(remoteAddress, remoteCert, false, badCert))
}
eventProcessor?.close()
ctx.fireChannelInactive()
}

override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) {
if (evt is SslHandshakeCompletionEvent) {
if (evt.isSuccess) {
val sslHandler = ctx.pipeline().get(SslHandler::class.java)
localCert = sslHandler.engine().session.localCertificates[0].x509
remoteCert = sslHandler.engine().session.peerCertificates[0].x509
val remoteX500Name = try {
CordaX500Name.build(remoteCert!!.subjectX500Principal)
} catch (ex: IllegalArgumentException) {
badCert = true
logErrorWithMDC("Certificate subject not a valid CordaX500Name", ex)
ctx.close()
return
when (evt) {
is ProxyConnectionEvent -> {
if (trace) {
log.info("ProxyConnectionEvent received: $evt")
try {
ctx.pipeline().remove(PROXY_LOGGER_NAME)
} catch (ex: NoSuchElementException) {
// ignore
}
}
if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) {
badCert = true
logErrorWithMDC("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames")
ctx.close()
return
}
logInfoWithMDC("Handshake completed with subject: $remoteX500Name")
createAMQPEngine(ctx)
onOpen(Pair(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, true, false)))
} else {
val cause = evt.cause()
// This happens when the peer node is closed during SSL establishment.
if (cause is ClosedChannelException) {
logWarnWithMDC("SSL Handshake closed early.")
} else if (cause is SSLException && cause.message == "handshake timed out") { // Sadly the exception thrown by Netty wrapper requires that we check the message.
logWarnWithMDC("SSL Handshake timed out")
} else {
badCert = true
}
logErrorWithMDC("Handshake failure ${evt.cause().message}")
if (log.isTraceEnabled) {
withMDC { log.trace("Handshake failure", evt.cause()) }
}
ctx.close()
// update address to the real target address
remoteAddress = evt.destinationAddress()
}
is SniCompletionEvent -> {
if (evt.isSuccess) {
// The SniCompletionEvent is fired up before context is switched (after SslHandshakeCompletionEvent)
// so we save the requested server name now to be able to log it once the handshake is completed successfully
// Note: this event is only triggered when using OpenSSL.
requestedServerName = evt.hostname()
logInfoWithMDC { "SNI completion success." }
} else {
logErrorWithMDC("SNI completion failure: ${evt.cause().message}")
}
}
is SslHandshakeCompletionEvent -> {
if (evt.isSuccess) {
handleSuccessfulHandshake(ctx)
} else {
handleFailedHandshake(ctx, evt)
}
}
}
}

private fun SslHandler.getRequestedServerName(): String? {
return if (serverMode) {
val session = engine().session
when (session) {
// Server name can be obtained from SSL session when using JavaSSL.
is ExtendedSSLSession -> (session.requestedServerNames.firstOrNull() as? SNIHostName)?.asciiName
// For Open SSL server name is obtained from SniCompletionEvent
else -> requestedServerName
}
} else {
(engine().sslParameters?.serverNames?.firstOrNull() as? SNIHostName)?.asciiName
}
}

@ -158,6 +191,10 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
if (log.isTraceEnabled) {
withMDC { log.trace("Pipeline uncaught exception", cause) }
}
if (cause is ProxyConnectException) {
log.warn("Proxy connection failed ${cause.message}")
suppressClose = true // The pipeline gets marked as active on connection to the proxy rather than to the target, which causes excess close events
}
ctx.close()
}

@ -176,27 +213,27 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
try {
try {
when (msg) {
// Transfers application packet into the AMQP engine.
// Transfers application packet into the AMQP engine.
is SendableMessageImpl -> {
val inetAddress = InetSocketAddress(msg.destinationLink.host, msg.destinationLink.port)
logDebugWithMDC { "Message for endpoint $inetAddress , expected $remoteAddress "}
logDebugWithMDC { "Message for endpoint $inetAddress , expected $remoteAddress " }

require(CordaX500Name.parse(msg.destinationLegalName) == CordaX500Name.build(remoteCert!!.subjectX500Principal)) {
"Message for incorrect legal identity ${msg.destinationLegalName} expected ${remoteCert!!.subjectX500Principal}"
}
logDebugWithMDC { "channel write ${msg.applicationProperties["_AMQ_DUPL_ID"]}" }
logDebugWithMDC { "channel write ${msg.applicationProperties[MESSAGE_ID_KEY]}" }
eventProcessor!!.transportWriteMessage(msg)
}
// A received AMQP packet has been completed and this self-posted packet will be signalled out to the
// external application.
// A received AMQP packet has been completed and this self-posted packet will be signalled out to the
// external application.
is ReceivedMessage -> {
onReceive(msg)
}
// A general self-posted event that triggers creation of AMQP frames when required.
// A general self-posted event that triggers creation of AMQP frames when required.
is Transport -> {
eventProcessor!!.transportProcessOutput(ctx)
}
// A self-posted event that forwards status updates for delivered packets to the application.
// A self-posted event that forwards status updates for delivered packets to the application.
is ReceivedMessageImpl.MessageCompleter -> {
eventProcessor!!.complete(msg)
}
@ -210,4 +247,67 @@ internal class AMQPChannelHandler(private val serverMode: Boolean,
}
eventProcessor!!.processEventsAsync()
}

private fun handleSuccessfulHandshake(ctx: ChannelHandlerContext) {
val sslHandler = ctx.pipeline().get(SslHandler::class.java)
val sslSession = sslHandler.engine().session
// Depending on what matching method is used, getting the local certificate is done by selecting the
// appropriate keyManagerFactory
val keyManagerFactory = requestedServerName?.let {
keyManagerFactoriesMap[it]
} ?: keyManagerFactoriesMap.values.single()

localCert = keyManagerFactory.getCurrentCertChain()?.first()

if (localCert == null) {
log.error("SSL KeyManagerFactory failed to provide a local cert")
ctx.close()
return
}
if (sslSession.peerCertificates == null || sslSession.peerCertificates.isEmpty()) {
log.error("No peer certificates")
ctx.close()
return
}
remoteCert = sslHandler.engine().session.peerCertificates.first().x509
val remoteX500Name = try {
CordaX500Name.build(remoteCert!!.subjectX500Principal)
} catch (ex: IllegalArgumentException) {
badCert = true
logErrorWithMDC("Certificate subject not a valid CordaX500Name", ex)
ctx.close()
return
}
if (allowedRemoteLegalNames != null && remoteX500Name !in allowedRemoteLegalNames) {
badCert = true
logErrorWithMDC("Provided certificate subject $remoteX500Name not in expected set $allowedRemoteLegalNames")
ctx.close()
return
}

logInfoWithMDC { "Handshake completed with subject: $remoteX500Name, requested server name: ${sslHandler.getRequestedServerName()}." }
createAMQPEngine(ctx)
onOpen(ctx.channel() as SocketChannel, ConnectionChange(remoteAddress, remoteCert, connected = true, badCert = false))
}

private fun handleFailedHandshake(ctx: ChannelHandlerContext, evt: SslHandshakeCompletionEvent) {
val cause = evt.cause()
// This happens when the peer node is closed during SSL establishment.
when {
cause is ClosedChannelException -> logWarnWithMDC("SSL Handshake closed early.")
// Sadly the exception thrown by Netty wrapper requires that we check the message.
cause is SSLException && cause.message == "handshake timed out" -> logWarnWithMDC("SSL Handshake timed out")
cause is SSLException && (cause.message?.contains("close_notify") == true)
-> logWarnWithMDC("Received close_notify during handshake")
// io.netty.handler.ssl.SslHandler.setHandshakeFailureTransportFailure()
cause is SSLException && (cause.message?.contains("writing TLS control frames") == true) -> logWarnWithMDC(cause.message!!)

else -> badCert = true
}
logWarnWithMDC("Handshake failure: ${evt.cause().message}")
if (log.isTraceEnabled) {
withMDC { log.trace("Handshake failure", evt.cause()) }
}
ctx.close()
}
}
@ -7,24 +7,45 @@ import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.handler.logging.LogLevel
import io.netty.handler.logging.LoggingHandler
import io.netty.handler.proxy.HttpProxyHandler
import io.netty.handler.proxy.Socks4ProxyHandler
import io.netty.handler.proxy.Socks5ProxyHandler
import io.netty.resolver.NoopAddressResolverGroup
import io.netty.util.internal.logging.InternalLoggerFactory
import io.netty.util.internal.logging.Slf4JLoggerFactory
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPChannelHandler.Companion.PROXY_LOGGER_NAME
import net.corda.nodeapi.internal.requireMessageSize
import rx.Observable
import rx.subjects.PublishSubject
import java.lang.Long.min
import java.net.InetSocketAddress
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.TrustManagerFactory
import kotlin.concurrent.withLock

enum class ProxyVersion {
SOCKS4,
SOCKS5,
HTTP
}

data class ProxyConfig(val version: ProxyVersion, val proxyAddress: NetworkHostAndPort, val userName: String? = null, val password: String? = null, val proxyTimeoutMS: Long? = null) {
init {
if (version == ProxyVersion.SOCKS4) {
require(password == null) { "SOCKS4 does not support a password" }
}
}
}

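A small usage sketch of the ProxyConfig introduced above; the host, port and credentials are placeholders. Note that the init block rejects a password for SOCKS4.

// Illustrative only: route outgoing AMQP connections through a SOCKS5 proxy.
val socks5 = ProxyConfig(
    version = ProxyVersion.SOCKS5,
    proxyAddress = NetworkHostAndPort("proxy.example.com", 1080),
    userName = "proxyUser",        // optional
    password = "proxyPassword",    // optional, not allowed for SOCKS4
    proxyTimeoutMS = 10_000
)
// ProxyConfig(ProxyVersion.SOCKS4, NetworkHostAndPort("proxy.example.com", 1080), password = "x")
// would fail the require() check in the init block above.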
/**
* The AMQPClient creates a connection initiator that will try to connect in a round-robin fashion
* to the first open SSL socket. It will keep retrying until it is stopped.
@ -42,15 +63,18 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
}

val log = contextLogger()
const val MIN_RETRY_INTERVAL = 1000L
const val MAX_RETRY_INTERVAL = 60000L
const val BACKOFF_MULTIPLIER = 2L
const val NUM_CLIENT_THREADS = 2

private const val CORDA_AMQP_NUM_CLIENT_THREAD_PROP_NAME = "net.corda.nodeapi.amqpclient.NumClientThread"

private const val MIN_RETRY_INTERVAL = 1000L
private const val MAX_RETRY_INTERVAL = 60000L
private const val BACKOFF_MULTIPLIER = 2L
private val NUM_CLIENT_THREADS = Integer.getInteger(CORDA_AMQP_NUM_CLIENT_THREAD_PROP_NAME, 2)
}

private val lock = ReentrantLock()
@Volatile
private var stopping: Boolean = false
private var started: Boolean = false
private var workerGroup: EventLoopGroup? = null
@Volatile
private var clientChannel: Channel? = null
@ -59,6 +83,13 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
private var currentTarget: NetworkHostAndPort = targets.first()
private var retryInterval = MIN_RETRY_INTERVAL
private val badCertTargets = mutableSetOf<NetworkHostAndPort>()
@Volatile
private var amqpActive = false
@Volatile
private var amqpChannelHandler: ChannelHandler? = null

val localAddressString: String
get() = clientChannel?.localAddress()?.toString() ?: "<unknownLocalAddress>"

private fun nextTarget() {
val origIndex = targetIndex
@ -80,29 +111,31 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,

private val connectListener = object : ChannelFutureListener {
override fun operationComplete(future: ChannelFuture) {
amqpActive = false
if (!future.isSuccess) {
log.info("Failed to connect to $currentTarget")

if (!stopping) {
if (started) {
workerGroup?.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
}
} else {
log.info("Connected to $currentTarget")
// Connection established successfully
clientChannel = future.channel()
clientChannel?.closeFuture()?.addListener(closeListener)
log.info("Connected to $currentTarget, Local address: $localAddressString")
}
}
}

private val closeListener = ChannelFutureListener { future ->
log.info("Disconnected from $currentTarget")
log.info("Disconnected from $currentTarget, Local address: $localAddressString")
future.channel()?.disconnect()
clientChannel = null
if (!stopping) {
if (started && !amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP inactive)" }
workerGroup?.schedule({
nextTarget()
restart()
@ -114,42 +147,110 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
private val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
private val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
private val conf = parent.configuration
@Volatile
private lateinit var amqpChannelHandler: AMQPChannelHandler

init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.crlCheckSoftFail))
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.revocationConfig))
}

@Suppress("ComplexMethod")
override fun initChannel(ch: SocketChannel) {
val pipeline = ch.pipeline()
val proxyConfig = conf.proxyConfig
if (proxyConfig != null) {
if (conf.trace) pipeline.addLast(PROXY_LOGGER_NAME, LoggingHandler(LogLevel.INFO))
val proxyAddress = InetSocketAddress(proxyConfig.proxyAddress.host, proxyConfig.proxyAddress.port)
val proxy = when (conf.proxyConfig!!.version) {
ProxyVersion.SOCKS4 -> {
Socks4ProxyHandler(proxyAddress, proxyConfig.userName)
}
ProxyVersion.SOCKS5 -> {
Socks5ProxyHandler(proxyAddress, proxyConfig.userName, proxyConfig.password)
}
ProxyVersion.HTTP -> {
val httpProxyHandler = if(proxyConfig.userName == null || proxyConfig.password == null) {
HttpProxyHandler(proxyAddress)
} else {
HttpProxyHandler(proxyAddress, proxyConfig.userName, proxyConfig.password)
}
//httpProxyHandler.setConnectTimeoutMillis(3600000) // 1hr for debugging purposes
httpProxyHandler
}
}
val proxyTimeout = proxyConfig.proxyTimeoutMS
if (proxyTimeout != null) {
proxy.setConnectTimeoutMillis(proxyTimeout)
}
pipeline.addLast("Proxy", proxy)
proxy.connectFuture().addListener {
if (!it.isSuccess) {
ch.disconnect()
}
}
}

val wrappedKeyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, parent.configuration)
val target = parent.currentTarget
val handler = createClientSslHelper(target, parent.allowedRemoteLegalNames, keyManagerFactory, trustManagerFactory)
val handler = if (parent.configuration.useOpenSsl) {
createClientOpenSslHandler(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory, ch.alloc())
} else {
createClientSslHelper(target, parent.allowedRemoteLegalNames, wrappedKeyManagerFactory, trustManagerFactory)
}
handler.handshakeTimeoutMillis = conf.sslHandshakeTimeout
pipeline.addLast("sslHandler", handler)
if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
pipeline.addLast(AMQPChannelHandler(false,
amqpChannelHandler = AMQPChannelHandler(false,
parent.allowedRemoteLegalNames,
// Single entry, key can be anything.
mapOf(DEFAULT to wrappedKeyManagerFactory),
conf.userName,
conf.password,
conf.trace,
{
parent.retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly
parent._onConnection.onNext(it.second)
},
{
parent._onConnection.onNext(it.second)
if (it.second.badCert) {
log.error("Blocking future connection attempts to $target due to bad certificate on endpoint")
parent.badCertTargets += target
false,
onOpen = { _, change ->
parent.run {
amqpActive = true
retryInterval = MIN_RETRY_INTERVAL // reset to fast reconnect if we connect properly
_onConnection.onNext(change)
}
},
{ rcv -> parent._onReceive.onNext(rcv) }))
onClose = { _, change ->
if (parent.amqpChannelHandler == amqpChannelHandler) {
parent.run {
_onConnection.onNext(change)
if (change.badCert) {
log.error("Blocking future connection attempts to $target due to bad certificate on endpoint")
badCertTargets += target
}

if (started && amqpActive) {
log.debug { "Scheduling restart of $currentTarget (AMQP active)" }
workerGroup?.schedule({
nextTarget()
restart()
}, retryInterval, TimeUnit.MILLISECONDS)
}
amqpActive = false
}
}
},
onReceive = { rcv -> parent._onReceive.onNext(rcv) })
parent.amqpChannelHandler = amqpChannelHandler
pipeline.addLast(amqpChannelHandler)
}
}

fun start() {
lock.withLock {
log.info("connect to: $currentTarget")
if (started) {
log.info("Already connected to: $currentTarget so returning")
return
}
log.info("Connect to: $currentTarget")
workerGroup = sharedThreadPool ?: NioEventLoopGroup(NUM_CLIENT_THREADS)
started = true
restart()
}
}
@ -161,6 +262,10 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
val bootstrap = Bootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
bootstrap.group(workerGroup).channel(NioSocketChannel::class.java).handler(ClientChannelInitializer(this))
// Delegate DNS Resolution to the proxy side, if we are using proxy.
if (configuration.proxyConfig != null) {
bootstrap.resolver(NoopAddressResolverGroup.INSTANCE)
}
currentTarget = targets[targetIndex]
val clientFuture = bootstrap.connect(currentTarget.host, currentTarget.port)
clientFuture.addListener(connectListener)
@ -168,21 +273,17 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,

fun stop() {
lock.withLock {
log.info("disconnect from: $currentTarget")
stopping = true
try {
if (sharedThreadPool == null) {
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
} else {
clientChannel?.close()?.sync()
}
clientChannel = null
workerGroup = null
} finally {
stopping = false
log.info("Stopping connection to: $currentTarget, Local address: $localAddressString")
started = false
if (sharedThreadPool == null) {
workerGroup?.shutdownGracefully()
workerGroup?.terminationFuture()?.sync()
} else {
clientChannel?.close()?.sync()
}
log.info("stopped connection to $currentTarget")
clientChannel = null
workerGroup = null
log.info("Stopped connection to $currentTarget")
}
}

@ -191,7 +292,7 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,
val connected: Boolean
get() {
val channel = lock.withLock { clientChannel }
return channel?.isActive ?: false
return isChannelWritable(channel)
}

fun createMessage(payload: ByteArray,
@ -204,13 +305,17 @@ class AMQPClient(val targets: List<NetworkHostAndPort>,

fun write(msg: SendableMessage) {
val channel = clientChannel
if (channel == null) {
if (channel == null || !isChannelWritable(channel)) {
throw IllegalStateException("Connection to $targets not active")
} else {
channel.writeAndFlush(msg)
}
}

private fun isChannelWritable(channel: Channel?): Boolean {
return channel?.let { channel.isOpen && channel.isActive && amqpActive } ?: false
}

private val _onReceive = PublishSubject.create<ReceivedMessage>().toSerialized()
val onReceive: Observable<ReceivedMessage>
get() = _onReceive
@ -2,7 +2,7 @@ package net.corda.nodeapi.internal.protonwrapper.netty

import net.corda.nodeapi.internal.ArtemisMessagingComponent
import net.corda.nodeapi.internal.config.CertificateStore
import java.security.KeyStore
import net.corda.nodeapi.internal.config.DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS

interface AMQPConfiguration {
/**
@ -32,12 +32,11 @@ interface AMQPConfiguration {
val trustStore: CertificateStore

/**
* Setting crlCheckSoftFail to true allows certificate paths where some leaf certificates do not contain cRLDistributionPoints
* and also allows validation to continue if the CRL distribution server is not contactable.
* Controls how the CRL check is performed.
*/
@JvmDefault
val crlCheckSoftFail: Boolean
get() = true
val revocationConfig: RevocationConfig
get() = RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL)

/**
* Enables full debug tracing of all netty and AMQP level packets. This logs at very high volume and is only for developers.
@ -51,5 +50,41 @@ interface AMQPConfiguration {
* but currently that is deferred to Artemis and the bridge code.
*/
val maxMessageSize: Int

@JvmDefault
val proxyConfig: ProxyConfig?
get() = null

@JvmDefault
val sourceX500Name: String?
get() = null

/**
* Whether to use the tcnative open/boring SSL provider or the default Java SSL provider
*/
@JvmDefault
val useOpenSsl: Boolean
get() = false

@JvmDefault
val sslHandshakeTimeout: Long
get() = DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS // Aligned with sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT

/**
* An optional health check phrase which, if passed through the channel, will cause the AMQP server to echo it back instead of doing normal pipeline processing
*/
val healthCheckPhrase: String?
get() = null

/**
* An optional set of IPv4/IPv6 remote address strings which will be compared to the remote address of inbound connections; matching connections will only log at TRACE level
*/
@JvmDefault
val silencedIPs: Set<String>
get() = emptySet()

@JvmDefault
val enableSNI: Boolean
get() = true
}

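Because most members of AMQPConfiguration now carry @JvmDefault defaults, a configuration can be customised with Kotlin interface delegation instead of re-implementing every property. A hedged sketch, assuming an existing baseConfig instance is available:

// Illustrative only: enable the health check echo mode and silence a probe host,
// delegating every other property to an existing configuration.
fun withProbeSupport(baseConfig: AMQPConfiguration): AMQPConfiguration {
    return object : AMQPConfiguration by baseConfig {
        override val healthCheckPhrase: String? = "HealthCheckPhrase"
        override val silencedIPs: Set<String> = setOf("10.0.0.5")
    }
}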
@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.protonwrapper.netty

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.Channel
import io.netty.channel.ChannelHandler
import io.netty.channel.ChannelInitializer
import io.netty.channel.ChannelOption
import io.netty.channel.EventLoopGroup
@ -14,6 +15,7 @@ import io.netty.util.internal.logging.InternalLoggerFactory
import io.netty.util.internal.logging.Slf4JLoggerFactory
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.nodeapi.internal.protonwrapper.messages.ReceivedMessage
import net.corda.nodeapi.internal.protonwrapper.messages.SendableMessage
import net.corda.nodeapi.internal.protonwrapper.messages.impl.SendableMessageImpl
@ -31,7 +33,6 @@ import kotlin.concurrent.withLock

/**
* This creates a socket acceptor instance that can receive possibly multiple AMQP connections.
* As of now this is not used outside of testing, but in future it will be used for standalone bridging components.
*/
class AMQPServer(val hostName: String,
val port: Int,
@ -42,8 +43,10 @@ class AMQPServer(val hostName: String,
InternalLoggerFactory.setDefaultFactory(Slf4JLoggerFactory.INSTANCE)
}

private const val CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME = "net.corda.nodeapi.amqpserver.NumServerThreads"

private val log = contextLogger()
const val NUM_SERVER_THREADS = 4
private val NUM_SERVER_THREADS = Integer.getInteger(CORDA_AMQP_NUM_SERVER_THREAD_PROP_NAME, 4)
}

private val lock = ReentrantLock()
@ -60,29 +63,59 @@ class AMQPServer(val hostName: String,
private val conf = parent.configuration

init {
keyManagerFactory.init(conf.keyStore)
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.crlCheckSoftFail))
keyManagerFactory.init(conf.keyStore.value.internal, conf.keyStore.entryPassword.toCharArray())
trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(conf.trustStore, conf.revocationConfig))
}

override fun initChannel(ch: SocketChannel) {
val amqpConfiguration = parent.configuration
val pipeline = ch.pipeline()
val handler = createServerSslHelper(keyManagerFactory, trustManagerFactory)
pipeline.addLast("sslHandler", handler)
amqpConfiguration.healthCheckPhrase?.let { pipeline.addLast(ModeSelectingChannel.NAME, ModeSelectingChannel(it)) }
val (sslHandler, keyManagerFactoriesMap) = createSSLHandler(amqpConfiguration, ch)
pipeline.addLast("sslHandler", sslHandler)
if (conf.trace) pipeline.addLast("logger", LoggingHandler(LogLevel.INFO))
val suppressLogs = ch.remoteAddress()?.hostString in amqpConfiguration.silencedIPs
pipeline.addLast(AMQPChannelHandler(true,
null,
// Passing a mapping of legal names to key managers to be able to pick the correct one after
// SNI completion event is fired up.
keyManagerFactoriesMap,
conf.userName,
conf.password,
conf.trace,
{
parent.clientChannels[it.first.remoteAddress()] = it.first
parent._onConnection.onNext(it.second)
suppressLogs,
onOpen = { channel, change ->
parent.run {
clientChannels[channel.remoteAddress()] = channel
_onConnection.onNext(change)
}
},
{
parent.clientChannels.remove(it.first.remoteAddress())
parent._onConnection.onNext(it.second)
onClose = { channel, change ->
parent.run {
val remoteAddress = channel.remoteAddress()
clientChannels.remove(remoteAddress)
_onConnection.onNext(change)
}
},
{ rcv -> parent._onReceive.onNext(rcv) }))
onReceive = { rcv -> parent._onReceive.onNext(rcv) }))
}

private fun createSSLHandler(amqpConfig: AMQPConfiguration, ch: SocketChannel): Pair<ChannelHandler, Map<String, CertHoldingKeyManagerFactoryWrapper>> {
return if (amqpConfig.useOpenSsl && amqpConfig.enableSNI && amqpConfig.keyStore.aliases().size > 1) {
val keyManagerFactoriesMap = splitKeystore(amqpConfig)
// SNI matching needed only when multiple nodes exist behind the server.
Pair(createServerSNIOpenSslHandler(keyManagerFactoriesMap, trustManagerFactory), keyManagerFactoriesMap)
} else {
val keyManagerFactory = CertHoldingKeyManagerFactoryWrapper(keyManagerFactory, amqpConfig)
val handler = if (amqpConfig.useOpenSsl) {
createServerOpenSslHandler(keyManagerFactory, trustManagerFactory, ch.alloc())
} else {
// For javaSSL, SNI matching is handled at key manager level.
createServerSslHandler(amqpConfig.keyStore, keyManagerFactory, trustManagerFactory)
}
handler.handshakeTimeoutMillis = amqpConfig.sslHandshakeTimeout
Pair(handler, mapOf(DEFAULT to keyManagerFactory))
}
}
}

@ -95,7 +128,10 @@ class AMQPServer(val hostName: String,

val server = ServerBootstrap()
// TODO Needs more configuration control when we profile. e.g. to use EPOLL on Linux
server.group(bossGroup, workerGroup).channel(NioServerSocketChannel::class.java).option(ChannelOption.SO_BACKLOG, 100).handler(LoggingHandler(LogLevel.INFO)).childHandler(ServerChannelInitializer(this))
server.group(bossGroup, workerGroup).channel(NioServerSocketChannel::class.java)
.option(ChannelOption.SO_BACKLOG, 100)
.handler(NettyServerEventLogger(LogLevel.INFO, configuration.silencedIPs))
.childHandler(ServerChannelInitializer(this))

log.info("Try to bind $port")
val channelFuture = server.bind(hostName, port).sync() // block/throw here as better to know we failed to claim port than carry on
@ -144,7 +180,7 @@ class AMQPServer(val hostName: String,
requireMessageSize(payload.size, configuration.maxMessageSize)
val dest = InetSocketAddress(destinationLink.host, destinationLink.port)
require(dest in clientChannels.keys) {
"Destination not available"
"Destination $dest is not available"
}
return SendableMessageImpl(payload, topic, destinationLegalName, destinationLink, properties)
}
@ -155,21 +191,22 @@ class AMQPServer(val hostName: String,
if (channel == null) {
throw IllegalStateException("Connection to ${msg.destinationLink} not active")
} else {
log.debug { "Writing message with payload of size ${msg.payload.size} into channel $channel" }
channel.writeAndFlush(msg)
log.debug { "Done writing message with payload of size ${msg.payload.size} into channel $channel" }
}
}

fun dropConnection(connectionRemoteHost: InetSocketAddress) {
val channel = clientChannels[connectionRemoteHost]
if (channel != null) {
channel.close()
}
clientChannels[connectionRemoteHost]?.close()
}

fun complete(delivery: Delivery, target: InetSocketAddress) {
val channel = clientChannels[target]
channel?.apply {
log.debug { "Writing delivery $delivery into channel $channel" }
writeAndFlush(delivery)
log.debug { "Done writing delivery $delivery into channel $channel" }
}
}

@ -0,0 +1,60 @@
package net.corda.nodeapi.internal.protonwrapper.netty

import java.net.Socket
import java.security.Principal
import javax.net.ssl.SSLEngine
import javax.net.ssl.X509ExtendedKeyManager
import javax.net.ssl.X509KeyManager

interface AliasProvidingKeyMangerWrapper : X509KeyManager {
var lastAlias: String?
}


class AliasProvidingKeyMangerWrapperImpl(private val keyManager: X509KeyManager) : AliasProvidingKeyMangerWrapper, X509KeyManager by keyManager {
override var lastAlias: String? = null

override fun chooseServerAlias(keyType: String?, issuers: Array<out Principal>?, socket: Socket?): String? {
return storeIfNotNull { keyManager.chooseServerAlias(keyType, issuers, socket) }
}

override fun chooseClientAlias(keyType: Array<out String>?, issuers: Array<out Principal>?, socket: Socket?): String? {
return storeIfNotNull { keyManager.chooseClientAlias(keyType, issuers, socket) }
}

private fun storeIfNotNull(func: () -> String?): String? {
val alias = func()
if (alias != null) {
lastAlias = alias
}
return alias
}
}

class AliasProvidingExtendedKeyMangerWrapper(private val keyManager: X509ExtendedKeyManager) : X509ExtendedKeyManager(), X509KeyManager by keyManager, AliasProvidingKeyMangerWrapper {
override var lastAlias: String? = null

override fun chooseServerAlias(keyType: String?, issuers: Array<out Principal>?, socket: Socket?): String? {
return storeIfNotNull { keyManager.chooseServerAlias(keyType, issuers, socket) }
}

override fun chooseClientAlias(keyType: Array<out String>?, issuers: Array<out Principal>?, socket: Socket?): String? {
return storeIfNotNull { keyManager.chooseClientAlias(keyType, issuers, socket) }
}

override fun chooseEngineClientAlias(keyType: Array<out String>?, issuers: Array<out Principal>?, engine: SSLEngine?): String? {
return storeIfNotNull { keyManager.chooseEngineClientAlias(keyType, issuers, engine) }
}

override fun chooseEngineServerAlias(keyType: String?, issuers: Array<out Principal>?, engine: SSLEngine?): String? {
return storeIfNotNull { keyManager.chooseEngineServerAlias(keyType, issuers, engine) }
}

private fun storeIfNotNull(func: () -> String?): String? {
val alias = func()
if (alias != null) {
lastAlias = alias
}
return alias
}
}
@ -0,0 +1,34 @@
package net.corda.nodeapi.internal.protonwrapper.netty

import net.corda.core.utilities.debug
import org.slf4j.LoggerFactory
import java.security.cert.CertPathValidatorException
import java.security.cert.Certificate
import java.security.cert.PKIXRevocationChecker
import java.util.*

object AllowAllRevocationChecker : PKIXRevocationChecker() {

private val logger = LoggerFactory.getLogger(AllowAllRevocationChecker::class.java)

override fun check(cert: Certificate?, unresolvedCritExts: MutableCollection<String>?) {
logger.debug {"Passing certificate check for: $cert"}
// Nothing to do
}

override fun isForwardCheckingSupported(): Boolean {
return true
}

override fun getSupportedExtensions(): MutableSet<String>? {
return null
}

override fun init(forward: Boolean) {
// Nothing to do
}

override fun getSoftFailExceptions(): MutableList<CertPathValidatorException> {
return LinkedList()
}
}
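A sketch of how a PKIXRevocationChecker such as the one above is typically attached to certificate path validation; the certificate path and trust anchors are assumed to come from an existing trust store, and this is only an illustration of the JDK API, not code from this change.

// Illustrative only: validate a certificate path with revocation checking effectively disabled
// by installing AllowAllRevocationChecker as the path checker.
import java.security.cert.CertPath
import java.security.cert.CertPathValidator
import java.security.cert.PKIXParameters
import java.security.cert.TrustAnchor

fun validateIgnoringRevocation(certPath: CertPath, trustAnchors: Set<TrustAnchor>) {
    val params = PKIXParameters(trustAnchors).apply {
        isRevocationEnabled = false                       // turn off the built-in revocation checker
        addCertPathChecker(AllowAllRevocationChecker)     // and add the permissive one instead
    }
    CertPathValidator.getInstance("PKIX").validate(certPath, params)
}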
@ -0,0 +1,81 @@
package net.corda.nodeapi.internal.protonwrapper.netty

import java.security.KeyStore
import java.security.cert.X509Certificate
import javax.net.ssl.KeyManager
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.KeyManagerFactorySpi
import javax.net.ssl.ManagerFactoryParameters
import javax.net.ssl.X509ExtendedKeyManager
import javax.net.ssl.X509KeyManager

class CertHoldingKeyManagerFactorySpiWrapper(private val factorySpi: KeyManagerFactorySpi, private val amqpConfig: AMQPConfiguration) : KeyManagerFactorySpi() {
override fun engineInit(keyStore: KeyStore?, password: CharArray?) {
val engineInitMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineInit", KeyStore::class.java, CharArray::class.java)
engineInitMethod.isAccessible = true
engineInitMethod.invoke(factorySpi, keyStore, password)
}

override fun engineInit(spec: ManagerFactoryParameters?) {
val engineInitMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineInit", ManagerFactoryParameters::class.java)
engineInitMethod.isAccessible = true
engineInitMethod.invoke(factorySpi, spec)
}

private fun getKeyManagersImpl(): Array<KeyManager> {
val engineGetKeyManagersMethod = KeyManagerFactorySpi::class.java.getDeclaredMethod("engineGetKeyManagers")
engineGetKeyManagersMethod.isAccessible = true
@Suppress("UNCHECKED_CAST")
val keyManagers = engineGetKeyManagersMethod.invoke(factorySpi) as Array<KeyManager>
return if (factorySpi is CertHoldingKeyManagerFactorySpiWrapper) keyManagers else keyManagers.map {
val aliasProvidingKeyManager = getDefaultKeyManager(it)
// Use the SNIKeyManager if keystore has several entries and only for clients and non-openSSL servers.
// Condition of using SNIKeyManager: if its client, or JDKSsl server.
val isClient = amqpConfig.sourceX500Name != null
val enableSNI = amqpConfig.enableSNI && amqpConfig.keyStore.aliases().size > 1
if (enableSNI && (isClient || !amqpConfig.useOpenSsl)) {
SNIKeyManager(aliasProvidingKeyManager as X509ExtendedKeyManager, amqpConfig)
} else {
aliasProvidingKeyManager
}
}.toTypedArray()
}

private fun getDefaultKeyManager(keyManager: KeyManager): KeyManager {
return when (keyManager) {
is X509ExtendedKeyManager -> AliasProvidingExtendedKeyMangerWrapper(keyManager)
is X509KeyManager -> AliasProvidingKeyMangerWrapperImpl(keyManager)
else -> throw UnsupportedOperationException("Supported key manager types are: X509ExtendedKeyManager, X509KeyManager. Provided ${keyManager::class.java.name}")
}
}

private val keyManagers = lazy { getKeyManagersImpl() }

override fun engineGetKeyManagers(): Array<KeyManager> {
return keyManagers.value
}
}

/**
* You can wrap a key manager factory in this class if you need to get the cert chain currently used to identify or
* verify. When using for TLS channels, make sure to wrap the (singleton) factory separately on each channel, as
* the wrapper is not thread safe: it will return the last used alias/cert chain and has itself no notion
* of belonging to a certain channel.
*/
class CertHoldingKeyManagerFactoryWrapper(factory: KeyManagerFactory, amqpConfig: AMQPConfiguration) : KeyManagerFactory(getFactorySpi(factory, amqpConfig), factory.provider, factory.algorithm) {
companion object {
private fun getFactorySpi(factory: KeyManagerFactory, amqpConfig: AMQPConfiguration): KeyManagerFactorySpi {
val spiField = KeyManagerFactory::class.java.getDeclaredField("factorySpi")
spiField.isAccessible = true
return CertHoldingKeyManagerFactorySpiWrapper(spiField.get(factory) as KeyManagerFactorySpi, amqpConfig)
}
}

fun getCurrentCertChain(): Array<out X509Certificate>? {
val keyManager = keyManagers.firstOrNull()
val alias = if (keyManager is AliasProvidingKeyMangerWrapper) keyManager.lastAlias else null
return if (alias != null && keyManager is X509KeyManager) {
keyManager.getCertificateChain(alias)
} else null
}
}
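A usage sketch for the wrapper above, mirroring how the client initializer uses it: initialise a standard KeyManagerFactory first, then wrap it so the chain chosen during the handshake can be read back. The amqpConfig, key store and password here are assumed to exist already.

// Illustrative only: wrap the platform KeyManagerFactory so the certificate chain actually
// selected during the TLS handshake can be inspected afterwards.
import java.security.KeyStore
import javax.net.ssl.KeyManagerFactory

fun wrappedFactory(amqpConfig: AMQPConfiguration, keyStore: KeyStore, password: CharArray): CertHoldingKeyManagerFactoryWrapper {
    val factory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
    factory.init(keyStore, password)
    return CertHoldingKeyManagerFactoryWrapper(factory, amqpConfig)
}
// After a handshake has used this factory's key managers:
// val chain = wrapped.getCurrentCertChain()   // null until an alias has been chosen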
@ -0,0 +1,12 @@
package net.corda.nodeapi.internal.protonwrapper.netty

import java.security.cert.X509CRL
import java.security.cert.X509Certificate

interface ExternalCrlSource {

/**
* For the given certificate, provides a set of CRLs, potentially performing remote communication.
*/
fun fetch(certificate: X509Certificate) : Set<X509CRL>
}
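A hedged sketch of an ExternalCrlSource implementation that serves CRLs from a local directory instead of fetching them remotely; the directory layout and file extension are hypothetical and error handling is omitted.

// Illustrative only: load every *.crl file from a fixed directory and return them all,
// regardless of which certificate is being checked.
import java.io.File
import java.security.cert.CertificateFactory
import java.security.cert.X509CRL
import java.security.cert.X509Certificate

class DirectoryCrlSource(private val directory: File) : ExternalCrlSource {
    private val certificateFactory = CertificateFactory.getInstance("X.509")

    override fun fetch(certificate: X509Certificate): Set<X509CRL> {
        val files = directory.listFiles { file -> file.extension == "crl" } ?: return emptySet()
        return files.map { file -> file.inputStream().use { certificateFactory.generateCRL(it) as X509CRL } }.toSet()
    }
}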
@ -0,0 +1,76 @@
package net.corda.nodeapi.internal.protonwrapper.netty

import io.netty.buffer.ByteBuf
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.ByteToMessageDecoder
import io.netty.handler.ssl.SslHandler
import net.corda.core.utilities.contextLogger

/**
* Responsible for deciding whether we are likely to be processing a health probe request
* or this is a normal SSL/AMQP processing pipeline
*/
internal class ModeSelectingChannel(healthCheckPhrase: String) : ByteToMessageDecoder() {

companion object {
const val NAME = "modeSelector"
private val log = contextLogger()
}

private enum class TriState {
UNDECIDED,
ECHO_MODE,
NORMAL_MODE
}

private val healthCheckPhraseArray = healthCheckPhrase.toByteArray(Charsets.UTF_8)

private var currentMode = TriState.UNDECIDED

private var alreadyEchoedPos = 0

override fun decode(ctx: ChannelHandlerContext, inByteBuf: ByteBuf, out: MutableList<Any>?) {

fun ChannelHandlerContext.echoBack(inByteBuf: ByteBuf) {

// WriteAndFlush() will decrement count and will blow unless we retain first
// And we have to ensure we are not sending the same information multiple times
val toBeWritten = inByteBuf.retainedSlice(alreadyEchoedPos, inByteBuf.readableBytes() - alreadyEchoedPos)

writeAndFlush(toBeWritten)

alreadyEchoedPos = inByteBuf.readableBytes()
}

if(currentMode == TriState.ECHO_MODE) {
ctx.echoBack(inByteBuf)
return
}

// Wait until the length prefix is available.
if (inByteBuf.readableBytes() < healthCheckPhraseArray.size) {
return
}

// Direct buffers do not allow calling `.array()` on them, see `io.netty.buffer.UnpooledDirectByteBuf.array`
val incomingArray = Unpooled.copiedBuffer(inByteBuf).array()
val zipped = healthCheckPhraseArray.zip(incomingArray)
if (zipped.all { it.first == it.second }) {
// Matched the healthCheckPhrase
currentMode = TriState.ECHO_MODE
log.info("Echo mode activated for connection ${ctx.channel().id()}")
// Cancel scheduled action to avoid SSL handshake timeout, which starts "ticking" once the connection is established,
// namely upon call to `io.netty.handler.ssl.SslHandler#handlerAdded` is made
ctx.pipeline().get(SslHandler::class.java)?.handshakeFuture()?.cancel(false)
ctx.echoBack(inByteBuf)
} else {
currentMode = TriState.NORMAL_MODE
// Remove self from pipeline and replay all the messages received down the pipeline
// It is important to bump-up reference count as pipeline removal decrements it by one.
inByteBuf.retain()
ctx.pipeline().remove(this)
ctx.fireChannelRead(inByteBuf)
}
}
}
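Given the echo behaviour above, a health probe can be as simple as writing the configured phrase to the raw socket and checking that the same bytes come back. A minimal sketch, assuming the probe knows the phrase and endpoint that the server was configured with; it is not part of this change.

// Illustrative only: poke the AMQP port with the health check phrase and expect an echo.
import java.net.Socket

fun probe(host: String, port: Int, healthCheckPhrase: String): Boolean {
    Socket(host, port).use { socket ->
        val phraseBytes = healthCheckPhrase.toByteArray(Charsets.UTF_8)
        socket.getOutputStream().apply {
            write(phraseBytes)
            flush()
        }
        val echoed = ByteArray(phraseBytes.size)
        var read = 0
        while (read < echoed.size) {
            val n = socket.getInputStream().read(echoed, read, echoed.size - read)
            if (n < 0) return false   // connection closed before the full echo arrived
            read += n
        }
        return echoed.contentEquals(phraseBytes)
    }
}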
@ -0,0 +1,73 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty
|
||||
|
||||
import io.netty.channel.ChannelDuplexHandler
|
||||
import io.netty.channel.ChannelHandler
|
||||
import io.netty.channel.ChannelHandlerContext
|
||||
import io.netty.channel.ChannelPromise
|
||||
import io.netty.handler.logging.LogLevel
|
||||
import io.netty.util.internal.logging.InternalLogLevel
|
||||
import io.netty.util.internal.logging.InternalLogger
|
||||
import io.netty.util.internal.logging.InternalLoggerFactory
|
||||
import java.net.SocketAddress
|
||||
|
||||
@ChannelHandler.Sharable
|
||||
class NettyServerEventLogger(level: LogLevel = DEFAULT_LEVEL, val silencedIPs: Set<String> = emptySet()) : ChannelDuplexHandler() {
|
||||
companion object {
|
||||
val DEFAULT_LEVEL: LogLevel = LogLevel.DEBUG
|
||||
}
|
||||
|
||||
private val logger: InternalLogger = InternalLoggerFactory.getInstance(javaClass)
|
||||
private val internalLevel: InternalLogLevel = level.toInternalLevel()
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun channelActive(ctx: ChannelHandlerContext) {
|
||||
if (logger.isEnabled(internalLevel)) {
|
||||
logger.log(internalLevel, "Server socket ${ctx.channel()} ACTIVE")
|
||||
}
|
||||
ctx.fireChannelActive()
|
||||
}
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun channelInactive(ctx: ChannelHandlerContext) {
|
||||
if (logger.isEnabled(internalLevel)) {
|
||||
logger.log(internalLevel, "Server socket ${ctx.channel()} INACTIVE")
|
||||
}
|
||||
ctx.fireChannelInactive()
|
||||
}
|
||||
|
||||
@Suppress("OverridingDeprecatedMember")
|
||||
@Throws(Exception::class)
|
||||
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
|
||||
if (logger.isEnabled(internalLevel)) {
|
||||
logger.log(internalLevel, "Server socket ${ctx.channel()} EXCEPTION ${cause.message}", cause)
|
||||
}
|
||||
ctx.fireExceptionCaught(cause)
|
||||
}
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun bind(ctx: ChannelHandlerContext, localAddress: SocketAddress, promise: ChannelPromise) {
|
||||
if (logger.isEnabled(internalLevel)) {
|
||||
logger.log(internalLevel, "Server socket ${ctx.channel()} BIND $localAddress")
|
||||
}
|
||||
ctx.bind(localAddress, promise)
|
||||
}
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun close(ctx: ChannelHandlerContext, promise: ChannelPromise) {
|
||||
if (logger.isEnabled(internalLevel)) {
|
||||
logger.log(internalLevel, "Server socket ${ctx.channel()} CLOSE")
|
||||
}
|
||||
ctx.close(promise)
|
||||
}
|
||||
|
||||
@Throws(Exception::class)
|
||||
override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
|
||||
val level = if (msg is io.netty.channel.socket.SocketChannel) { // Should always be the case as this is a server socket, but be defensive
|
||||
if (msg.remoteAddress()?.hostString !in silencedIPs) internalLevel else InternalLogLevel.TRACE
|
||||
} else internalLevel
|
||||
if (logger.isEnabled(level)) {
|
||||
logger.log(level, "Server socket ${ctx.channel()} ACCEPTED $msg")
|
||||
}
|
||||
ctx.fireChannelRead(msg)
|
||||
}
|
||||
}
|
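NettyServerEventLogger is sharable and is meant to sit on the parent (server) channel, where channelRead sees each accepted SocketChannel; connections from addresses listed in silencedIPs are demoted to TRACE so that load-balancer probes do not flood the logs. A minimal wiring sketch, assuming an illustrative probe address of 10.0.0.1 and a placeholder child pipeline:

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.ChannelInitializer
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.logging.LogLevel

fun startServer(port: Int) {
    ServerBootstrap()
            .group(NioEventLoopGroup(1), NioEventLoopGroup())
            .channel(NioServerSocketChannel::class.java)
            // Logs BIND/ACTIVE/ACCEPTED/CLOSE for the server socket, silencing 10.0.0.1.
            .handler(NettyServerEventLogger(LogLevel.INFO, setOf("10.0.0.1")))
            .childHandler(object : ChannelInitializer<SocketChannel>() {
                override fun initChannel(ch: SocketChannel) {
                    // The real SSL/AMQP pipeline would be installed here.
                }
            })
            .bind(port)
}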
@ -0,0 +1,83 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty
|
||||
|
||||
import com.typesafe.config.Config
|
||||
import net.corda.nodeapi.internal.config.ConfigParser
|
||||
import net.corda.nodeapi.internal.config.CustomConfigParser
|
||||
|
||||
/**
|
||||
* Data structure controlling how Certificate Revocation Lists (CRLs) are handled.
|
||||
*/
|
||||
@CustomConfigParser(RevocationConfigParser::class)
|
||||
interface RevocationConfig {
|
||||
|
||||
enum class Mode {
|
||||
|
||||
/**
|
||||
* @see java.security.cert.PKIXRevocationChecker.Option.SOFT_FAIL
|
||||
*/
|
||||
SOFT_FAIL,
|
||||
|
||||
/**
|
||||
* Opposite of SOFT_FAIL - i.e. most rigorous check.
|
||||
* Among other things, this check requires that a CRL checking URL is available at every level of the certificate chain.
|
||||
* This is also known as Strict mode.
|
||||
*/
|
||||
HARD_FAIL,
|
||||
|
||||
/**
|
||||
* CRLs are obtained from an external source.
|
||||
* @see ExternalCrlSource
|
||||
*/
|
||||
EXTERNAL_SOURCE,
|
||||
|
||||
/**
|
||||
* Switch CRL check off.
|
||||
*/
|
||||
OFF
|
||||
}
|
||||
|
||||
val mode: Mode
|
||||
|
||||
/**
|
||||
* Optional `ExternalCrlSource`, which is only meaningful when `mode` = `EXTERNAL_SOURCE`.
|
||||
*/
|
||||
val externalCrlSource: ExternalCrlSource?
|
||||
|
||||
/**
|
||||
* Creates a copy of this `RevocationConfig` enriched with an `ExternalCrlSource`.
|
||||
*/
|
||||
fun enrichExternalCrlSource(sourceFunc: (() -> ExternalCrlSource)?): RevocationConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* Maintained for legacy purposes to convert the old-style `crlCheckSoftFail` flag.
|
||||
*/
|
||||
fun Boolean.toRevocationConfig() = if (this) RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL) else RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)
|
||||
|
||||
data class RevocationConfigImpl(override val mode: RevocationConfig.Mode, override val externalCrlSource: ExternalCrlSource? = null) : RevocationConfig {
|
||||
override fun enrichExternalCrlSource(sourceFunc: (() -> ExternalCrlSource)?): RevocationConfig {
|
||||
return if(mode != RevocationConfig.Mode.EXTERNAL_SOURCE) {
|
||||
this
|
||||
} else {
|
||||
assert(sourceFunc != null) { "There should be a way to obtain ExternalCrlSource" }
|
||||
copy(externalCrlSource = sourceFunc!!())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class RevocationConfigParser : ConfigParser<RevocationConfig> {
|
||||
override fun parse(config: Config): RevocationConfig {
|
||||
val oneAndTheOnly = "mode"
|
||||
val allKeys = config.entrySet().map { it.key }
|
||||
require(allKeys.size == 1 && allKeys.contains(oneAndTheOnly)) {"For RevocationConfig, it is expected to have '$oneAndTheOnly' property only. " +
|
||||
"Actual set of properties: $allKeys. Please check 'revocationConfig' section."}
|
||||
val mode = config.getString(oneAndTheOnly)
|
||||
return when (mode.toUpperCase()) {
|
||||
"SOFT_FAIL" -> RevocationConfigImpl(RevocationConfig.Mode.SOFT_FAIL)
|
||||
"HARD_FAIL" -> RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)
|
||||
"EXTERNAL_SOURCE" -> RevocationConfigImpl(RevocationConfig.Mode.EXTERNAL_SOURCE, null) // null for now till `enrichExternalCrlSource` is called
|
||||
"OFF" -> RevocationConfigImpl(RevocationConfig.Mode.OFF)
|
||||
else -> throw IllegalArgumentException("Unsupported mode : '$mode'")
|
||||
}
|
||||
}
|
||||
}
|
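RevocationConfigParser accepts exactly one key, mode, so a node configuration only has to name the checking strategy; the EXTERNAL_SOURCE case deliberately starts with a null source and relies on enrichExternalCrlSource being called later. A parsing sketch, assuming the block sits under an illustrative 'revocationConfig' key and that the code lives alongside these classes:

import com.typesafe.config.ConfigFactory

fun parseRevocationConfig(): RevocationConfig {
    // The HOCON text is illustrative; only the single 'mode' key is permitted by the parser.
    val nodeConfig = ConfigFactory.parseString("revocationConfig { mode = EXTERNAL_SOURCE }")
    // EXTERNAL_SOURCE starts with a null source; the node enriches it later via enrichExternalCrlSource.
    return RevocationConfigParser().parse(nodeConfig.getConfig("revocationConfig"))
}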
@ -0,0 +1,112 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty
|
||||
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.nodeapi.internal.config.CertificateStore
|
||||
import net.corda.nodeapi.internal.crypto.x509
|
||||
import org.slf4j.MDC
|
||||
import java.net.Socket
|
||||
import java.security.Principal
|
||||
import javax.net.ssl.SNIMatcher
|
||||
import javax.net.ssl.SSLEngine
|
||||
import javax.net.ssl.SSLSocket
|
||||
import javax.net.ssl.X509ExtendedKeyManager
|
||||
import javax.net.ssl.X509KeyManager
|
||||
|
||||
internal class SNIKeyManager(private val keyManager: X509ExtendedKeyManager, private val amqpConfig: AMQPConfiguration) : X509ExtendedKeyManager(), X509KeyManager by keyManager, AliasProvidingKeyMangerWrapper {
|
||||
|
||||
companion object {
|
||||
private val log = contextLogger()
|
||||
}
|
||||
|
||||
override var lastAlias: String? = null
|
||||
|
||||
private fun withMDC(block: () -> Unit) {
|
||||
val oldMDC = MDC.getCopyOfContextMap()
|
||||
try {
|
||||
MDC.put("lastAlias", lastAlias)
|
||||
MDC.put("isServer", amqpConfig.sourceX500Name.isNullOrEmpty().toString())
|
||||
MDC.put("sourceX500Name", amqpConfig.sourceX500Name)
|
||||
MDC.put("useOpenSSL", amqpConfig.useOpenSsl.toString())
|
||||
block()
|
||||
} finally {
|
||||
MDC.setContextMap(oldMDC)
|
||||
}
|
||||
}
|
||||
|
||||
private fun logDebugWithMDC(msg: () -> String) {
|
||||
if (log.isDebugEnabled) {
|
||||
withMDC { log.debug(msg()) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun chooseClientAlias(keyType: Array<out String>, issuers: Array<out Principal>, socket: Socket): String? {
|
||||
return storeIfNotNull { chooseClientAlias(amqpConfig.keyStore, amqpConfig.sourceX500Name) }
|
||||
}
|
||||
|
||||
override fun chooseEngineClientAlias(keyType: Array<out String>, issuers: Array<out Principal>, engine: SSLEngine): String? {
|
||||
return storeIfNotNull { chooseClientAlias(amqpConfig.keyStore, amqpConfig.sourceX500Name) }
|
||||
}
|
||||
|
||||
override fun chooseServerAlias(keyType: String?, issuers: Array<out Principal>?, socket: Socket): String? {
|
||||
return storeIfNotNull {
|
||||
val matcher = (socket as SSLSocket).sslParameters.sniMatchers.first()
|
||||
chooseServerAlias(keyType, issuers, matcher)
|
||||
}
|
||||
}
|
||||
|
||||
override fun chooseEngineServerAlias(keyType: String?, issuers: Array<out Principal>?, engine: SSLEngine?): String? {
|
||||
return storeIfNotNull {
|
||||
val matcher = engine?.sslParameters?.sniMatchers?.first()
|
||||
chooseServerAlias(keyType, issuers, matcher)
|
||||
}
|
||||
}
|
||||
|
||||
private fun chooseServerAlias(keyType: String?, issuers: Array<out Principal>?, matcher: SNIMatcher?): String? {
|
||||
val aliases = keyManager.getServerAliases(keyType, issuers)
|
||||
if (aliases == null || aliases.isEmpty()) {
|
||||
logDebugWithMDC { "Keystore doesn't contain any aliases for key type $keyType and issuers $issuers." }
|
||||
return null
|
||||
}
|
||||
|
||||
log.debug("Checking aliases: $aliases.")
|
||||
matcher?.let {
|
||||
val matchedAlias = (it as ServerSNIMatcher).matchedAlias
|
||||
if (aliases.contains(matchedAlias)) {
|
||||
logDebugWithMDC { "Found match for $matchedAlias." }
|
||||
return matchedAlias
|
||||
}
|
||||
}
|
||||
|
||||
logDebugWithMDC { "Unable to find a matching alias." }
|
||||
return null
|
||||
}
|
||||
|
||||
private fun chooseClientAlias(keyStore: CertificateStore, clientLegalName: String?): String? {
|
||||
clientLegalName?.let {
|
||||
val aliases = keyStore.aliases()
|
||||
if (aliases.isEmpty()) {
|
||||
logDebugWithMDC { "Keystore doesn't contain any entries." }
|
||||
}
|
||||
aliases.forEach { alias ->
|
||||
val x500Name = keyStore[alias].x509.subjectX500Principal
|
||||
val aliasCordaX500Name = CordaX500Name.build(x500Name)
|
||||
val clientCordaX500Name = CordaX500Name.parse(it)
|
||||
if (clientCordaX500Name == aliasCordaX500Name) {
|
||||
logDebugWithMDC { "Found alias $alias for $clientCordaX500Name." }
|
||||
return alias
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
private fun storeIfNotNull(func: () -> String?): String? {
|
||||
val alias = func()
|
||||
if (alias != null) {
|
||||
lastAlias = alias
|
||||
}
|
||||
return alias
|
||||
}
|
||||
}
|
@ -1,56 +1,118 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty
|
||||
|
||||
import io.netty.buffer.ByteBufAllocator
|
||||
import io.netty.handler.ssl.ClientAuth
|
||||
import io.netty.handler.ssl.SniHandler
|
||||
import io.netty.handler.ssl.SslContextBuilder
|
||||
import io.netty.handler.ssl.SslHandler
|
||||
import io.netty.handler.ssl.SslProvider
|
||||
import io.netty.util.DomainNameMappingBuilder
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.newSecureRandom
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.internal.VisibleForTesting
|
||||
import net.corda.core.utilities.NetworkHostAndPort
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.core.utilities.toHex
|
||||
import net.corda.nodeapi.internal.ArtemisTcpTransport
|
||||
import net.corda.nodeapi.internal.config.CertificateStore
|
||||
import net.corda.nodeapi.internal.crypto.toBc
|
||||
import net.corda.nodeapi.internal.crypto.x509
|
||||
import net.corda.nodeapi.internal.protonwrapper.netty.revocation.ExternalSourceRevocationChecker
|
||||
import org.bouncycastle.asn1.ASN1InputStream
|
||||
import org.bouncycastle.asn1.DERIA5String
|
||||
import org.bouncycastle.asn1.DEROctetString
|
||||
import org.bouncycastle.asn1.x509.AuthorityKeyIdentifier
|
||||
import org.bouncycastle.asn1.x509.CRLDistPoint
|
||||
import org.bouncycastle.asn1.x509.DistributionPointName
|
||||
import org.bouncycastle.asn1.x509.Extension
|
||||
import org.bouncycastle.asn1.x509.GeneralName
|
||||
import org.bouncycastle.asn1.x509.GeneralNames
|
||||
import org.bouncycastle.asn1.x509.SubjectKeyIdentifier
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.net.Socket
|
||||
import java.security.KeyStore
|
||||
import java.security.cert.*
|
||||
import java.util.*
|
||||
import java.util.concurrent.Executor
|
||||
import javax.net.ssl.*
|
||||
import kotlin.system.measureTimeMillis
|
||||
|
||||
private const val HOSTNAME_FORMAT = "%s.corda.net"
|
||||
private const val SSL_HANDSHAKE_TIMEOUT_PROP_NAME = "corda.netty.sslHelper.handshakeTimeout"
|
||||
private const val DEFAULT_SSL_TIMEOUT = 20000 // Aligned with sun.security.provider.certpath.URICertStore.DEFAULT_CRL_CONNECT_TIMEOUT
|
||||
internal const val DEFAULT = "default"
|
||||
|
||||
internal class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
|
||||
internal const val DP_DEFAULT_ANSWER = "NO CRLDP ext"
|
||||
|
||||
internal val logger = LoggerFactory.getLogger("net.corda.nodeapi.internal.protonwrapper.netty.SSLHelper")
|
||||
|
||||
fun X509Certificate.distributionPoints() : Set<String>? {
|
||||
logger.debug("Checking CRLDPs for $subjectX500Principal")
|
||||
|
||||
val crldpExtBytes = getExtensionValue(Extension.cRLDistributionPoints.id)
|
||||
if (crldpExtBytes == null) {
|
||||
logger.debug(DP_DEFAULT_ANSWER)
|
||||
return emptySet()
|
||||
}
|
||||
|
||||
val derObjCrlDP = ASN1InputStream(ByteArrayInputStream(crldpExtBytes)).readObject()
|
||||
val dosCrlDP = derObjCrlDP as? DEROctetString
|
||||
if (dosCrlDP == null) {
|
||||
logger.error("Expected to have DEROctetString, actual type: ${derObjCrlDP.javaClass}")
|
||||
return emptySet()
|
||||
}
|
||||
val crldpExtOctetsBytes = dosCrlDP.octets
|
||||
val dpObj = ASN1InputStream(ByteArrayInputStream(crldpExtOctetsBytes)).readObject()
|
||||
val distPoint = CRLDistPoint.getInstance(dpObj)
|
||||
if (distPoint == null) {
|
||||
logger.error("Could not instantiate CRLDistPoint, from: $dpObj")
|
||||
return emptySet()
|
||||
}
|
||||
|
||||
val dpNames = distPoint.distributionPoints.mapNotNull { it.distributionPoint }.filter { it.type == DistributionPointName.FULL_NAME }
|
||||
val generalNames = dpNames.flatMap { GeneralNames.getInstance(it.name).names.asList() }
|
||||
return generalNames.filter { it.tagNo == GeneralName.uniformResourceIdentifier}.map { DERIA5String.getInstance(it.name).string }.toSet()
|
||||
}
|
||||
|
||||
fun X509Certificate.distributionPointsToString() : String {
|
||||
return with(distributionPoints()) {
|
||||
if(this == null || isEmpty()) {
|
||||
DP_DEFAULT_ANSWER
|
||||
} else {
|
||||
sorted().joinToString()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun certPathToString(certPath: Array<out X509Certificate>?): String {
|
||||
if (certPath == null) {
|
||||
return "<empty certpath>"
|
||||
}
|
||||
val certs = certPath.map {
|
||||
val bcCert = it.toBc()
|
||||
val subject = bcCert.subject.toString()
|
||||
val issuer = bcCert.issuer.toString()
|
||||
val keyIdentifier = try {
|
||||
SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex()
|
||||
} catch (ex: Exception) {
|
||||
"null"
|
||||
}
|
||||
val authorityKeyIdentifier = try {
|
||||
AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex()
|
||||
} catch (ex: Exception) {
|
||||
"null"
|
||||
}
|
||||
" $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier] [${it.distributionPointsToString()}]"
|
||||
}
|
||||
return certs.joinToString("\r\n")
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager) : X509ExtendedTrustManager() {
|
||||
companion object {
|
||||
val log = contextLogger()
|
||||
}
|
||||
|
||||
private fun certPathToString(certPath: Array<out X509Certificate>?): String {
|
||||
if (certPath == null) {
|
||||
return "<empty certpath>"
|
||||
}
|
||||
val certs = certPath.map {
|
||||
val bcCert = it.toBc()
|
||||
val subject = bcCert.subject.toString()
|
||||
val issuer = bcCert.issuer.toString()
|
||||
val keyIdentifier = try {
|
||||
SubjectKeyIdentifier.getInstance(bcCert.getExtension(Extension.subjectKeyIdentifier).parsedValue).keyIdentifier.toHex()
|
||||
} catch (ex: Exception) {
|
||||
"null"
|
||||
}
|
||||
val authorityKeyIdentifier = try {
|
||||
AuthorityKeyIdentifier.getInstance(bcCert.getExtension(Extension.authorityKeyIdentifier).parsedValue).keyIdentifier.toHex()
|
||||
} catch (ex: Exception) {
|
||||
"null"
|
||||
}
|
||||
" $subject[$keyIdentifier] issued by $issuer[$authorityKeyIdentifier]"
|
||||
}
|
||||
return certs.joinToString("\r\n")
|
||||
}
|
||||
|
||||
|
||||
private fun certPathToStringFull(chain: Array<out X509Certificate>?): String {
|
||||
if (chain == null) {
|
||||
return "<empty certpath>"
|
||||
@ -107,6 +169,33 @@ internal class LoggingTrustManagerWrapper(val wrapped: X509ExtendedTrustManager)
|
||||
|
||||
}
|
||||
|
||||
private object LoggingImmediateExecutor : Executor {
|
||||
|
||||
override fun execute(command: Runnable?) {
|
||||
val log = LoggerFactory.getLogger(javaClass)
|
||||
|
||||
if (command == null) {
|
||||
log.error("SSL handler executor called with a null command")
|
||||
throw NullPointerException("command")
|
||||
}
|
||||
|
||||
@Suppress("TooGenericExceptionCaught", "MagicNumber") // log and rethrow all exceptions
|
||||
try {
|
||||
val commandName = command::class.qualifiedName?.let { "[$it]" } ?: ""
|
||||
log.debug("Entering SSL command $commandName")
|
||||
val elapsedTime = measureTimeMillis { command.run() }
|
||||
log.debug("Exiting SSL command $elapsedTime millis")
|
||||
if (elapsedTime > 100) {
|
||||
log.info("Command: $commandName took $elapsedTime millis to execute")
|
||||
}
|
||||
}
|
||||
catch (ex: Exception) {
|
||||
log.error("Caught exception in SSL handler executor", ex)
|
||||
throw ex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal fun createClientSslHelper(target: NetworkHostAndPort,
|
||||
expectedRemoteLegalNames: Set<CordaX500Name>,
|
||||
keyManagerFactory: KeyManagerFactory,
|
||||
@ -125,13 +214,31 @@ internal fun createClientSslHelper(target: NetworkHostAndPort,
|
||||
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
|
||||
sslEngine.sslParameters = sslParameters
|
||||
}
|
||||
val sslHandler = SslHandler(sslEngine)
|
||||
sslHandler.handshakeTimeoutMillis = Integer.getInteger(SSL_HANDSHAKE_TIMEOUT_PROP_NAME, DEFAULT_SSL_TIMEOUT).toLong()
|
||||
return sslHandler
|
||||
@Suppress("DEPRECATION")
|
||||
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
|
||||
}
|
||||
|
||||
internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory,
|
||||
trustManagerFactory: TrustManagerFactory): SslHandler {
|
||||
internal fun createClientOpenSslHandler(target: NetworkHostAndPort,
|
||||
expectedRemoteLegalNames: Set<CordaX500Name>,
|
||||
keyManagerFactory: KeyManagerFactory,
|
||||
trustManagerFactory: TrustManagerFactory,
|
||||
alloc: ByteBufAllocator): SslHandler {
|
||||
val sslContext = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).keyManager(keyManagerFactory).trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory)).build()
|
||||
val sslEngine = sslContext.newEngine(alloc, target.host, target.port)
|
||||
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
|
||||
sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray()
|
||||
if (expectedRemoteLegalNames.size == 1) {
|
||||
val sslParameters = sslEngine.sslParameters
|
||||
sslParameters.serverNames = listOf(SNIHostName(x500toHostName(expectedRemoteLegalNames.single())))
|
||||
sslEngine.sslParameters = sslParameters
|
||||
}
|
||||
@Suppress("DEPRECATION")
|
||||
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
|
||||
}
|
||||
|
||||
internal fun createServerSslHandler(keyStore: CertificateStore,
|
||||
keyManagerFactory: KeyManagerFactory,
|
||||
trustManagerFactory: TrustManagerFactory): SslHandler {
|
||||
val sslContext = SSLContext.getInstance("TLS")
|
||||
val keyManagers = keyManagerFactory.keyManagers
|
||||
val trustManagers = trustManagerFactory.trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
|
||||
@ -142,35 +249,106 @@ internal fun createServerSslHelper(keyManagerFactory: KeyManagerFactory,
|
||||
sslEngine.enabledProtocols = ArtemisTcpTransport.TLS_VERSIONS.toTypedArray()
|
||||
sslEngine.enabledCipherSuites = ArtemisTcpTransport.CIPHER_SUITES.toTypedArray()
|
||||
sslEngine.enableSessionCreation = true
|
||||
val sslHandler = SslHandler(sslEngine)
|
||||
sslHandler.handshakeTimeoutMillis = Integer.getInteger(SSL_HANDSHAKE_TIMEOUT_PROP_NAME, DEFAULT_SSL_TIMEOUT).toLong()
|
||||
return sslHandler
|
||||
val sslParameters = sslEngine.sslParameters
|
||||
sslParameters.sniMatchers = listOf(ServerSNIMatcher(keyStore))
|
||||
sslEngine.sslParameters = sslParameters
|
||||
@Suppress("DEPRECATION")
|
||||
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
|
||||
}
|
||||
|
||||
internal fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, crlCheckSoftFail: Boolean): ManagerFactoryParameters {
|
||||
val certPathBuilder = CertPathBuilder.getInstance("PKIX")
|
||||
val revocationChecker = certPathBuilder.revocationChecker as PKIXRevocationChecker
|
||||
revocationChecker.options = EnumSet.of(
|
||||
// Prefer CRL over OCSP
|
||||
PKIXRevocationChecker.Option.PREFER_CRLS,
|
||||
// Don't fall back to OCSP checking
|
||||
PKIXRevocationChecker.Option.NO_FALLBACK)
|
||||
if (crlCheckSoftFail) {
|
||||
// Allow revocation check to succeed if the revocation status cannot be determined for one of
|
||||
// the following reasons: The CRL or OCSP response cannot be obtained because of a network error.
|
||||
revocationChecker.options = revocationChecker.options + PKIXRevocationChecker.Option.SOFT_FAIL
|
||||
}
|
||||
@VisibleForTesting
|
||||
fun initialiseTrustStoreAndEnableCrlChecking(trustStore: CertificateStore, revocationConfig: RevocationConfig): ManagerFactoryParameters {
|
||||
val pkixParams = PKIXBuilderParameters(trustStore.value.internal, X509CertSelector())
|
||||
val revocationChecker = when (revocationConfig.mode) {
|
||||
RevocationConfig.Mode.OFF -> AllowAllRevocationChecker // Custom PKIXRevocationChecker skipping CRL check
|
||||
RevocationConfig.Mode.EXTERNAL_SOURCE -> {
|
||||
require(revocationConfig.externalCrlSource != null) { "externalCrlSource must not be null" }
|
||||
ExternalSourceRevocationChecker(revocationConfig.externalCrlSource!!) { Date() } // Custom PKIXRevocationChecker which uses `externalCrlSource`
|
||||
}
|
||||
else -> {
|
||||
val certPathBuilder = CertPathBuilder.getInstance("PKIX")
|
||||
val pkixRevocationChecker = certPathBuilder.revocationChecker as PKIXRevocationChecker
|
||||
pkixRevocationChecker.options = EnumSet.of(
|
||||
// Prefer CRL over OCSP
|
||||
PKIXRevocationChecker.Option.PREFER_CRLS,
|
||||
// Don't fall back to OCSP checking
|
||||
PKIXRevocationChecker.Option.NO_FALLBACK)
|
||||
if (revocationConfig.mode == RevocationConfig.Mode.SOFT_FAIL) {
|
||||
// Allow revocation check to succeed if the revocation status cannot be determined for one of
|
||||
// the following reasons: The CRL or OCSP response cannot be obtained because of a network error.
|
||||
pkixRevocationChecker.options = pkixRevocationChecker.options + PKIXRevocationChecker.Option.SOFT_FAIL
|
||||
}
|
||||
pkixRevocationChecker
|
||||
}
|
||||
}
|
||||
pkixParams.addCertPathChecker(revocationChecker)
|
||||
return CertPathTrustManagerParameters(pkixParams)
|
||||
}
|
||||
|
||||
internal fun createServerOpenSslHandler(keyManagerFactory: KeyManagerFactory,
|
||||
trustManagerFactory: TrustManagerFactory,
|
||||
alloc: ByteBufAllocator): SslHandler {
|
||||
|
||||
val sslContext = getServerSslContextBuilder(keyManagerFactory, trustManagerFactory).build()
|
||||
val sslEngine = sslContext.newEngine(alloc)
|
||||
sslEngine.useClientMode = false
|
||||
@Suppress("DEPRECATION")
|
||||
return SslHandler(sslEngine, false, LoggingImmediateExecutor)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a special SNI handler used only when OpenSSL is used for the AMQPServer.
|
||||
*/
|
||||
internal fun createServerSNIOpenSslHandler(keyManagerFactoriesMap: Map<String, KeyManagerFactory>,
|
||||
trustManagerFactory: TrustManagerFactory): SniHandler {
|
||||
|
||||
// The default mapping value can be built from any entry in the map.
|
||||
val sslCtxBuilder = getServerSslContextBuilder(keyManagerFactoriesMap.values.first(), trustManagerFactory)
|
||||
val mapping = DomainNameMappingBuilder(sslCtxBuilder.build())
|
||||
keyManagerFactoriesMap.forEach {
|
||||
mapping.add(it.key, sslCtxBuilder.keyManager(it.value).build())
|
||||
}
|
||||
return SniHandler(mapping.build())
|
||||
}
|
||||
|
||||
@Suppress("SpreadOperator")
|
||||
private fun getServerSslContextBuilder(keyManagerFactory: KeyManagerFactory, trustManagerFactory: TrustManagerFactory): SslContextBuilder {
|
||||
return SslContextBuilder.forServer(keyManagerFactory)
|
||||
.sslProvider(SslProvider.OPENSSL)
|
||||
.trustManager(LoggingTrustManagerFactoryWrapper(trustManagerFactory))
|
||||
.clientAuth(ClientAuth.REQUIRE)
|
||||
.ciphers(ArtemisTcpTransport.CIPHER_SUITES)
|
||||
.protocols(*ArtemisTcpTransport.TLS_VERSIONS.toTypedArray())
|
||||
}
|
||||
|
||||
internal fun splitKeystore(config: AMQPConfiguration): Map<String, CertHoldingKeyManagerFactoryWrapper> {
|
||||
val keyStore = config.keyStore.value.internal
|
||||
val password = config.keyStore.entryPassword.toCharArray()
|
||||
return keyStore.aliases().toList().map { alias ->
|
||||
val key = keyStore.getKey(alias, password)
|
||||
val certs = keyStore.getCertificateChain(alias)
|
||||
val x500Name = keyStore.getCertificate(alias).x509.subjectX500Principal
|
||||
val cordaX500Name = CordaX500Name.build(x500Name)
|
||||
val newKeyStore = KeyStore.getInstance("JKS")
|
||||
newKeyStore.load(null)
|
||||
newKeyStore.setKeyEntry(alias, key, password, certs)
|
||||
val newKeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
|
||||
newKeyManagerFactory.init(newKeyStore, password)
|
||||
x500toHostName(cordaX500Name) to CertHoldingKeyManagerFactoryWrapper(newKeyManagerFactory, config)
|
||||
}.toMap()
|
||||
}
|
||||
|
||||
// As per Javadoc in: https://docs.oracle.com/javase/8/docs/api/javax/net/ssl/KeyManagerFactory.html `init` method
|
||||
// 2nd parameter `password` - the password for recovering keys in the KeyStore
|
||||
fun KeyManagerFactory.init(keyStore: CertificateStore) = init(keyStore.value.internal, keyStore.entryPassword.toCharArray())
|
||||
|
||||
fun TrustManagerFactory.init(trustStore: CertificateStore) = init(trustStore.value.internal)
|
||||
|
||||
/**
|
||||
* Method that converts a [CordaX500Name] to a valid hostname (RFC-1035). It's used for SNI to indicate the target
|
||||
* when trying to communicate with nodes that reside behind the same firewall. This works around the TLS SNI extension not
|
||||
* yet supporting X.500 names as server names.
|
||||
*/
|
||||
internal fun x500toHostName(x500Name: CordaX500Name): String {
|
||||
val secureHash = SecureHash.sha256(x500Name.toString())
|
||||
// RFC 1035 specifies a limit of 255 bytes for hostnames, with each label being 63 bytes or less. Due to this, the string
|
||||
|
@ -0,0 +1,47 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty
|
||||
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.nodeapi.internal.config.CertificateStore
|
||||
import net.corda.nodeapi.internal.crypto.x509
|
||||
import javax.net.ssl.SNIHostName
|
||||
import javax.net.ssl.SNIMatcher
|
||||
import javax.net.ssl.SNIServerName
|
||||
import javax.net.ssl.StandardConstants
|
||||
|
||||
class ServerSNIMatcher(private val keyStore: CertificateStore) : SNIMatcher(0) {
|
||||
|
||||
companion object {
|
||||
val log = contextLogger()
|
||||
}
|
||||
|
||||
var matchedAlias: String? = null
|
||||
private set
|
||||
var matchedServerName: String? = null
|
||||
private set
|
||||
|
||||
override fun matches(serverName: SNIServerName): Boolean {
|
||||
if (serverName.type == StandardConstants.SNI_HOST_NAME) {
|
||||
keyStore.aliases().forEach { alias ->
|
||||
val x500Name = keyStore[alias].x509.subjectX500Principal
|
||||
val cordaX500Name = CordaX500Name.build(x500Name)
|
||||
// Convert the CordaX500Name into the expected host name and compare
|
||||
// E.g. O=Corda B, L=London, C=GB becomes 3c6dd991936308edb210555103ffc1bb.corda.net
|
||||
if ((serverName as SNIHostName).asciiName == x500toHostName(cordaX500Name)) {
|
||||
matchedAlias = alias
|
||||
matchedServerName = serverName.asciiName
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val knownSNIValues = keyStore.aliases().joinToString {
|
||||
val x500Name = keyStore[it].x509.subjectX500Principal
|
||||
val cordaX500Name = CordaX500Name.build(x500Name)
|
||||
"hostname = ${x500toHostName(cordaX500Name)} alias = $it"
|
||||
}
|
||||
val requestedSNIValue = "hostname = ${(serverName as SNIHostName).asciiName}"
|
||||
log.warn("The requested SNI value [$requestedSNIValue] does not match any of the following known SNI values [$knownSNIValues]")
|
||||
return false
|
||||
}
|
||||
}
|
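ServerSNIMatcher hashes the X.500 subject of every keystore alias into the corda.net host name form and compares it with the SNI value offered by the client, recording the winning alias for SNIKeyManager to pick up. A wiring sketch for a plain JSSE engine, assuming an already built SSLContext and server CertificateStore:

import javax.net.ssl.SSLContext
import javax.net.ssl.SSLEngine
import net.corda.nodeapi.internal.config.CertificateStore

fun serverEngineWithSni(sslContext: SSLContext, serverKeyStore: CertificateStore): Pair<SSLEngine, ServerSNIMatcher> {
    val matcher = ServerSNIMatcher(serverKeyStore)
    val engine = sslContext.createSSLEngine().apply {
        useClientMode = false
        // Get-modify-set, as SSLEngine hands out a copy of its parameters.
        sslParameters = sslParameters.also { it.sniMatchers = listOf(matcher) }
    }
    // After the handshake, matcher.matchedAlias names the certificate chain that was selected.
    return engine to matcher
}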
@ -0,0 +1,40 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty
|
||||
|
||||
import java.security.KeyStore
|
||||
import javax.net.ssl.ManagerFactoryParameters
|
||||
import javax.net.ssl.TrustManager
|
||||
import javax.net.ssl.TrustManagerFactory
|
||||
import javax.net.ssl.TrustManagerFactorySpi
|
||||
import javax.net.ssl.X509ExtendedTrustManager
|
||||
|
||||
class LoggingTrustManagerFactorySpiWrapper(private val factorySpi: TrustManagerFactorySpi) : TrustManagerFactorySpi() {
|
||||
override fun engineGetTrustManagers(): Array<TrustManager> {
|
||||
val engineGetTrustManagersMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineGetTrustManagers")
|
||||
engineGetTrustManagersMethod.isAccessible = true
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val trustManagers = engineGetTrustManagersMethod.invoke(factorySpi) as Array<TrustManager>
|
||||
return if (factorySpi is LoggingTrustManagerFactorySpiWrapper) trustManagers else trustManagers.filterIsInstance(X509ExtendedTrustManager::class.java).map { LoggingTrustManagerWrapper(it) }.toTypedArray()
|
||||
}
|
||||
|
||||
override fun engineInit(ks: KeyStore?) {
|
||||
val engineInitMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineInit", KeyStore::class.java)
|
||||
engineInitMethod.isAccessible = true
|
||||
engineInitMethod.invoke(factorySpi, ks)
|
||||
}
|
||||
|
||||
override fun engineInit(spec: ManagerFactoryParameters?) {
|
||||
val engineInitMethod = TrustManagerFactorySpi::class.java.getDeclaredMethod("engineInit", ManagerFactoryParameters::class.java)
|
||||
engineInitMethod.isAccessible = true
|
||||
engineInitMethod.invoke(factorySpi, spec)
|
||||
}
|
||||
}
|
||||
|
||||
class LoggingTrustManagerFactoryWrapper(factory: TrustManagerFactory) : TrustManagerFactory(getFactorySpi(factory), factory.provider, factory.algorithm) {
|
||||
companion object {
|
||||
private fun getFactorySpi(factory: TrustManagerFactory): TrustManagerFactorySpi {
|
||||
val spiField = TrustManagerFactory::class.java.getDeclaredField("factorySpi")
|
||||
spiField.isAccessible = true
|
||||
return LoggingTrustManagerFactorySpiWrapper(spiField.get(factory) as TrustManagerFactorySpi)
|
||||
}
|
||||
}
|
||||
}
|
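LoggingTrustManagerFactoryWrapper reaches into the delegate factory's SPI by reflection so that every X509ExtendedTrustManager it hands out is wrapped in LoggingTrustManagerWrapper and therefore logs the certificate paths it is asked to validate. A usage sketch against the JVM default trust store; the node itself initialises against its own trust store, so this is illustration only:

import java.security.KeyStore
import javax.net.ssl.TrustManager
import javax.net.ssl.TrustManagerFactory

fun loggingDefaultTrustManagers(): Array<TrustManager> {
    val delegate = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
    val wrapper = LoggingTrustManagerFactoryWrapper(delegate)
    // A null KeyStore means the platform default trust store.
    wrapper.init(null as KeyStore?)
    return wrapper.trustManagers
}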
@ -0,0 +1,88 @@
|
||||
package net.corda.nodeapi.internal.protonwrapper.netty.revocation
|
||||
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import net.corda.nodeapi.internal.protonwrapper.netty.ExternalCrlSource
|
||||
import org.bouncycastle.asn1.x509.Extension
|
||||
import java.security.cert.CRLReason
|
||||
import java.security.cert.CertPathValidatorException
|
||||
import java.security.cert.Certificate
|
||||
import java.security.cert.CertificateRevokedException
|
||||
import java.security.cert.PKIXRevocationChecker
|
||||
import java.security.cert.X509CRL
|
||||
import java.security.cert.X509Certificate
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* Implementation of [PKIXRevocationChecker] which determines whether a certificate is revoked using [externalCrlSource], which knows how to
|
||||
* obtain a set of CRLs for a given certificate from an external source.
|
||||
*/
|
||||
class ExternalSourceRevocationChecker(private val externalCrlSource: ExternalCrlSource, private val dateSource: () -> Date) : PKIXRevocationChecker() {
|
||||
|
||||
companion object {
|
||||
private val logger = contextLogger()
|
||||
}
|
||||
|
||||
override fun check(cert: Certificate, unresolvedCritExts: MutableCollection<String>?) {
|
||||
val x509Certificate = cert as X509Certificate
|
||||
checkApprovedCRLs(x509Certificate, externalCrlSource.fetch(x509Certificate))
|
||||
}
|
||||
|
||||
/**
|
||||
* Borrowed from `RevocationChecker.checkApprovedCRLs()`
|
||||
*/
|
||||
@Suppress("NestedBlockDepth")
|
||||
@Throws(CertPathValidatorException::class)
|
||||
private fun checkApprovedCRLs(cert: X509Certificate, approvedCRLs: Set<X509CRL>) {
|
||||
// See if the cert is in the set of approved crls.
|
||||
logger.debug("ExternalSourceRevocationChecker.checkApprovedCRLs() cert SN: ${cert.serialNumber}")
|
||||
|
||||
for (crl in approvedCRLs) {
|
||||
val entry = crl.getRevokedCertificate(cert)
|
||||
if (entry != null) {
|
||||
logger.debug("ExternalSourceRevocationChecker.checkApprovedCRLs() CRL entry: $entry")
|
||||
|
||||
/*
|
||||
* Abort CRL validation and throw exception if there are any
|
||||
* unrecognized critical CRL entry extensions (see section
|
||||
* 5.3 of RFC 5280).
|
||||
*/
|
||||
val unresCritExts = entry.criticalExtensionOIDs
|
||||
if (unresCritExts != null && !unresCritExts.isEmpty()) {
|
||||
/* remove any that we will process */
|
||||
unresCritExts.remove(Extension.cRLDistributionPoints.id)
|
||||
unresCritExts.remove(Extension.certificateIssuer.id)
|
||||
if (!unresCritExts.isEmpty()) {
|
||||
throw CertPathValidatorException(
|
||||
"Unrecognized critical extension(s) in revoked CRL entry: $unresCritExts")
|
||||
}
|
||||
}
|
||||
|
||||
val reasonCode = entry.revocationReason ?: CRLReason.UNSPECIFIED
|
||||
val revocationDate = entry.revocationDate
|
||||
if (revocationDate.before(dateSource())) {
|
||||
val t = CertificateRevokedException(
|
||||
revocationDate, reasonCode,
|
||||
crl.issuerX500Principal, mutableMapOf())
|
||||
throw CertPathValidatorException(
|
||||
t.message, t, null, -1, CertPathValidatorException.BasicReason.REVOKED)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun isForwardCheckingSupported(): Boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
override fun getSupportedExtensions(): MutableSet<String>? {
|
||||
return null
|
||||
}
|
||||
|
||||
override fun init(forward: Boolean) {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
override fun getSoftFailExceptions(): MutableList<CertPathValidatorException> {
|
||||
return LinkedList()
|
||||
}
|
||||
}
|
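ExternalSourceRevocationChecker delegates CRL retrieval entirely to the ExternalCrlSource whose single fetch method opens this change, so plugging in a custom source is a matter of implementing that one method and enriching an EXTERNAL_SOURCE config with it. A minimal sketch with an assumed in-memory source keyed by issuer, placed alongside these classes:

import net.corda.nodeapi.internal.protonwrapper.netty.ExternalCrlSource
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfig
import net.corda.nodeapi.internal.protonwrapper.netty.RevocationConfigImpl
import java.security.cert.X509CRL
import java.security.cert.X509Certificate

// Hypothetical source: real deployments would fetch CRLs from a shared service or cache.
class InMemoryCrlSource(private val crlsByIssuer: Map<String, Set<X509CRL>>) : ExternalCrlSource {
    override fun fetch(certificate: X509Certificate): Set<X509CRL> =
            crlsByIssuer[certificate.issuerX500Principal.name] ?: emptySet()
}

fun externalRevocationConfig(source: ExternalCrlSource): RevocationConfig =
        RevocationConfigImpl(RevocationConfig.Mode.EXTERNAL_SOURCE).enrichExternalCrlSource { source }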
@ -66,7 +66,7 @@ class AMQPClientSerializationScheme(
|
||||
}
|
||||
}
|
||||
|
||||
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
|
||||
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: UseCase): Boolean {
|
||||
return magic == amqpMagic && (target == UseCase.RPCClient || target == UseCase.P2P)
|
||||
}
|
||||
|
||||
|
@ -18,10 +18,16 @@ class AMQPServerSerializationScheme(
|
||||
cordappSerializationWhitelists: Set<SerializationWhitelist>,
|
||||
serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>
|
||||
) : AbstractAMQPSerializationScheme(cordappCustomSerializers, cordappSerializationWhitelists, serializerFactoriesForContexts) {
|
||||
constructor(cordapps: List<Cordapp>) : this(cordapps.customSerializers, cordapps.serializationWhitelists, AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
|
||||
constructor(cordapps: List<Cordapp>, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>) : this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts)
|
||||
constructor(cordapps: List<Cordapp>) : this(cordapps.customSerializers, cordapps.serializationWhitelists)
|
||||
constructor(cordapps: List<Cordapp>, serializerFactoriesForContexts: MutableMap<SerializationFactoryCacheKey, SerializerFactory>)
|
||||
: this(cordapps.customSerializers, cordapps.serializationWhitelists, serializerFactoriesForContexts)
|
||||
constructor(
|
||||
cordappCustomSerializers: Set<SerializationCustomSerializer<*,*>>,
|
||||
cordappSerializationWhitelists: Set<SerializationWhitelist>
|
||||
) : this(cordappCustomSerializers, cordappSerializationWhitelists, AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised())
|
||||
|
||||
constructor() : this(emptySet(), emptySet(), AccessOrderLinkedHashMap<SerializationFactoryCacheKey, SerializerFactory>(128).toSynchronised() )
|
||||
@Suppress("UNUSED")
|
||||
constructor() : this(emptySet(), emptySet())
|
||||
|
||||
override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory {
|
||||
throw UnsupportedOperationException()
|
||||
|
@ -0,0 +1,136 @@
|
||||
package net.corda.nodeapi.internal.serialization.kryo
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.Serializer
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import java.lang.reflect.Constructor
|
||||
import java.lang.reflect.Field
|
||||
import java.util.LinkedHashMap
|
||||
import java.util.LinkedHashSet
|
||||
import java.util.LinkedList
|
||||
|
||||
/**
|
||||
* The [LinkedHashMap] and [LinkedHashSet] have a problem with the default Quasar/Kryo serialisation
|
||||
* in that serialising an iterator (and subsequent [LinkedHashMap.Entry]) over a sufficiently large
|
||||
* data set can lead to a stack overflow (because the object map is traversed recursively).
|
||||
*
|
||||
* We've added our own custom serializer in order to ensure that the iterator is correctly deserialized.
|
||||
*/
|
||||
internal object LinkedHashMapIteratorSerializer : Serializer<Iterator<*>>() {
|
||||
private val DUMMY_MAP = linkedMapOf(1L to 1)
|
||||
private val outerMapField: Field = getIterator()::class.java.superclass.getDeclaredField("this$0").apply { isAccessible = true }
|
||||
private val currentField: Field = getIterator()::class.java.superclass.getDeclaredField("current").apply { isAccessible = true }
|
||||
|
||||
private val KEY_ITERATOR_CLASS: Class<MutableIterator<Long>> = DUMMY_MAP.keys.iterator().javaClass
|
||||
private val VALUE_ITERATOR_CLASS: Class<MutableIterator<Int>> = DUMMY_MAP.values.iterator().javaClass
|
||||
private val MAP_ITERATOR_CLASS: Class<MutableIterator<MutableMap.MutableEntry<Long, Int>>> = DUMMY_MAP.iterator().javaClass
|
||||
|
||||
fun getIterator(): Any = DUMMY_MAP.iterator()
|
||||
|
||||
override fun write(kryo: Kryo, output: Output, obj: Iterator<*>) {
|
||||
val current: Map.Entry<*, *>? = currentField.get(obj) as Map.Entry<*, *>?
|
||||
kryo.writeClassAndObject(output, outerMapField.get(obj))
|
||||
kryo.writeClassAndObject(output, current)
|
||||
}
|
||||
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<Iterator<*>>): Iterator<*> {
|
||||
val outerMap = kryo.readClassAndObject(input) as Map<*, *>
|
||||
return when (type) {
|
||||
KEY_ITERATOR_CLASS -> {
|
||||
val current = (kryo.readClassAndObject(input) as? Map.Entry<*, *>)?.key
|
||||
outerMap.keys.iterator().returnToIteratorLocation(kryo, current)
|
||||
}
|
||||
VALUE_ITERATOR_CLASS -> {
|
||||
val current = (kryo.readClassAndObject(input) as? Map.Entry<*, *>)?.value
|
||||
outerMap.values.iterator().returnToIteratorLocation(kryo, current)
|
||||
}
|
||||
MAP_ITERATOR_CLASS -> {
|
||||
val current = (kryo.readClassAndObject(input) as? Map.Entry<*, *>)
|
||||
outerMap.iterator().returnToIteratorLocation(kryo, current)
|
||||
}
|
||||
else -> throw IllegalStateException("Invalid type")
|
||||
}
|
||||
}
|
||||
|
||||
private fun Iterator<*>.returnToIteratorLocation(kryo: Kryo, current: Any?): Iterator<*> {
|
||||
while (this.hasNext()) {
|
||||
val key = this.next()
|
||||
if (iteratedObjectsEqual(kryo, key, current)) break
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
private fun iteratedObjectsEqual(kryo: Kryo, a: Any?, b: Any?): Boolean = if (a == null || b == null) {
|
||||
a == b
|
||||
} else {
|
||||
a === b || mapEntriesEqual(kryo, a, b) || kryoOptimisesAwayReferencesButEqual(kryo, a, b)
|
||||
}
|
||||
|
||||
/**
|
||||
* Kryo can substitute brand-new instances for some types during deserialization, making the identity check fail.
|
||||
* Fall back to equality for those.
|
||||
*/
|
||||
private fun kryoOptimisesAwayReferencesButEqual(kryo: Kryo, a: Any, b: Any) =
|
||||
(!kryo.referenceResolver.useReferences(a.javaClass) && !kryo.referenceResolver.useReferences(b.javaClass) && a == b)
|
||||
|
||||
private fun mapEntriesEqual(kryo: Kryo, a: Any, b: Any) =
|
||||
(a is Map.Entry<*, *> && b is Map.Entry<*, *> && iteratedObjectsEqual(kryo, a.key, b.key))
|
||||
}
|
||||
|
||||
/**
|
||||
* The [LinkedHashMap] and [LinkedHashSet] have a problem with the default Quasar/Kryo serialisation
|
||||
* in that serialising an iterator (and subsequent [LinkedHashMap.Entry]) over a sufficiently large
|
||||
* data set can lead to a stack overflow (because the object map is traversed recursively).
|
||||
*
|
||||
* We've added our own custom serializer in order to ensure that only the key/value are recorded.
|
||||
* The rest of the list isn't required at this scope.
|
||||
*/
|
||||
object LinkedHashMapEntrySerializer : Serializer<Map.Entry<*, *>>() {
|
||||
// Create a dummy map so that we can get the LinkedHashMap$Entry from it
|
||||
// The element type of the map doesn't matter. The entry is all we want
|
||||
private val DUMMY_MAP = linkedMapOf(1L to 1)
|
||||
fun getEntry(): Any = DUMMY_MAP.entries.first()
|
||||
private val constr: Constructor<*> = getEntry()::class.java.declaredConstructors.single().apply { isAccessible = true }
|
||||
|
||||
/**
|
||||
* Kryo would end up serialising "this" entry, then serialise "this.after" recursively, leading to a very large stack.
|
||||
* We'll skip that and just write out the key/value.
|
||||
*/
|
||||
override fun write(kryo: Kryo, output: Output, obj: Map.Entry<*, *>) {
|
||||
val e: Map.Entry<*, *> = obj
|
||||
kryo.writeClassAndObject(output, e.key)
|
||||
kryo.writeClassAndObject(output, e.value)
|
||||
}
|
||||
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<Map.Entry<*, *>>): Map.Entry<*, *> {
|
||||
val key = kryo.readClassAndObject(input)
|
||||
val value = kryo.readClassAndObject(input)
|
||||
return constr.newInstance(0, key, value, null) as Map.Entry<*, *>
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Also, add a [ListIterator] serializer to avoid more linked list issues.
|
||||
*/
|
||||
object LinkedListItrSerializer : Serializer<ListIterator<Any>>() {
|
||||
// Create a dummy list so that we can get the ListItr from it
|
||||
// The element type of the list doesn't matter. The iterator is all we want
|
||||
private val DUMMY_LIST = LinkedList<Long>(listOf(1))
|
||||
fun getListItr(): Any = DUMMY_LIST.listIterator()
|
||||
|
||||
private val outerListField: Field = getListItr()::class.java.getDeclaredField("this$0").apply { isAccessible = true }
|
||||
|
||||
override fun write(kryo: Kryo, output: Output, obj: ListIterator<Any>) {
|
||||
kryo.writeClassAndObject(output, outerListField.get(obj))
|
||||
output.writeInt(obj.nextIndex())
|
||||
}
|
||||
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<ListIterator<Any>>): ListIterator<Any> {
|
||||
val list = kryo.readClassAndObject(input) as LinkedList<*>
|
||||
val index = input.readInt()
|
||||
return list.listIterator(index)
|
||||
}
|
||||
}
|
||||
|
||||
|
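These serializers only take effect once they are registered against the exact JDK-internal iterator and entry classes, which is what the dummy-collection helpers (getIterator, getEntry, getListItr) are for; DefaultKryoCustomizer performs the registrations further down. A standalone sketch of the same wiring on a bare Kryo instance:

import com.esotericsoftware.kryo.Kryo

fun registerLinkedCollectionSerializers(kryo: Kryo) {
    // Resolve the private JDK classes via the dummy collections, then register the custom serializers.
    kryo.addDefaultSerializer(LinkedHashMapIteratorSerializer.getIterator()::class.java.superclass, LinkedHashMapIteratorSerializer)
    kryo.register(LinkedHashMapEntrySerializer.getEntry()::class.java, LinkedHashMapEntrySerializer)
    kryo.register(LinkedListItrSerializer.getListItr()::class.java, LinkedListItrSerializer)
}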
@ -10,7 +10,11 @@ import com.esotericsoftware.kryo.serializers.FieldSerializer
|
||||
import de.javakaffee.kryoserializers.ArraysAsListSerializer
|
||||
import de.javakaffee.kryoserializers.BitSetSerializer
|
||||
import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer
|
||||
import de.javakaffee.kryoserializers.guava.*
|
||||
import de.javakaffee.kryoserializers.guava.ImmutableListSerializer
|
||||
import de.javakaffee.kryoserializers.guava.ImmutableMapSerializer
|
||||
import de.javakaffee.kryoserializers.guava.ImmutableMultimapSerializer
|
||||
import de.javakaffee.kryoserializers.guava.ImmutableSetSerializer
|
||||
import de.javakaffee.kryoserializers.guava.ImmutableSortedSetSerializer
|
||||
import net.corda.core.contracts.ContractAttachment
|
||||
import net.corda.core.contracts.ContractClassName
|
||||
import net.corda.core.contracts.PrivacySalt
|
||||
@ -24,7 +28,11 @@ import net.corda.core.serialization.MissingAttachmentsException
|
||||
import net.corda.core.serialization.SerializationWhitelist
|
||||
import net.corda.core.serialization.SerializeAsToken
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.transactions.*
|
||||
import net.corda.core.transactions.ContractUpgradeFilteredTransaction
|
||||
import net.corda.core.transactions.ContractUpgradeWireTransaction
|
||||
import net.corda.core.transactions.NotaryChangeWireTransaction
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.transactions.WireTransaction
|
||||
import net.corda.core.utilities.NonEmptySet
|
||||
import net.corda.core.utilities.toNonEmptySet
|
||||
import net.corda.serialization.internal.DefaultWhitelist
|
||||
@ -51,8 +59,9 @@ import java.security.PrivateKey
|
||||
import java.security.PublicKey
|
||||
import java.security.cert.CertPath
|
||||
import java.security.cert.X509Certificate
|
||||
import java.util.*
|
||||
import kotlin.collections.ArrayList
|
||||
import java.util.Arrays
|
||||
import java.util.BitSet
|
||||
import java.util.ServiceLoader
|
||||
|
||||
object DefaultKryoCustomizer {
|
||||
private val serializationWhitelists: List<SerializationWhitelist> by lazy {
|
||||
@ -70,7 +79,8 @@ object DefaultKryoCustomizer {
|
||||
instantiatorStrategy = CustomInstantiatorStrategy()
|
||||
|
||||
// Required for HashCheckingStream (de)serialization.
|
||||
// Note that return type should be specifically set to InputStream, otherwise it may not work, i.e. val aStream : InputStream = HashCheckingStream(...).
|
||||
// Note that return type should be specifically set to InputStream, otherwise it may not work,
|
||||
// i.e. val aStream : InputStream = HashCheckingStream(...).
|
||||
addDefaultSerializer(InputStream::class.java, InputStreamSerializer)
|
||||
addDefaultSerializer(SerializeAsToken::class.java, SerializeAsTokenSerializer<SerializeAsToken>())
|
||||
addDefaultSerializer(Logger::class.java, LoggerSerializer)
|
||||
@ -79,8 +89,10 @@ object DefaultKryoCustomizer {
|
||||
// WARNING: reordering the registrations here will cause a change in the serialized form, since classes
|
||||
// with custom serializers get written as registration ids. This will break backwards-compatibility.
|
||||
// Please add any new registrations to the end.
|
||||
// TODO: re-organise registrations into logical groups before v1.0
|
||||
|
||||
addDefaultSerializer(LinkedHashMapIteratorSerializer.getIterator()::class.java.superclass, LinkedHashMapIteratorSerializer)
|
||||
register(LinkedHashMapEntrySerializer.getEntry()::class.java, LinkedHashMapEntrySerializer)
|
||||
register(LinkedListItrSerializer.getListItr()::class.java, LinkedListItrSerializer)
|
||||
register(Arrays.asList("").javaClass, ArraysAsListSerializer())
|
||||
register(LazyMappedList::class.java, LazyMappedListSerializer)
|
||||
register(SignedTransaction::class.java, SignedTransactionSerializer)
|
||||
@ -129,6 +141,10 @@ object DefaultKryoCustomizer {
|
||||
register(ContractUpgradeWireTransaction::class.java, ContractUpgradeWireTransactionSerializer)
|
||||
register(ContractUpgradeFilteredTransaction::class.java, ContractUpgradeFilteredTransactionSerializer)
|
||||
|
||||
addDefaultSerializer(Iterator::class.java) { kryo, type ->
|
||||
IteratorSerializer(type, CompatibleFieldSerializer<Iterator<*>>(kryo, type).apply { setIgnoreSyntheticFields(false) })
|
||||
}
|
||||
|
||||
for (whitelistProvider in serializationWhitelists) {
|
||||
val types = whitelistProvider.whitelist
|
||||
require(types.toSet().size == types.size) {
|
||||
|
@ -0,0 +1,52 @@
|
||||
package net.corda.nodeapi.internal.serialization.kryo
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.Serializer
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import java.lang.reflect.Field
|
||||
|
||||
class IteratorSerializer(type: Class<*>, private val serializer: Serializer<Iterator<*>>) : Serializer<Iterator<*>>(false, false) {
|
||||
|
||||
private val iterableReferenceField = findField(type, "this\$0")?.apply { isAccessible = true }
|
||||
private val expectedModCountField = findField(type, "expectedModCount")?.apply { isAccessible = true }
|
||||
private val iterableReferenceFieldType = iterableReferenceField?.type
|
||||
private val modCountField = when (iterableReferenceFieldType) {
|
||||
null -> null
|
||||
else -> findField(iterableReferenceFieldType, "modCount")?.apply { isAccessible = true }
|
||||
}
|
||||
|
||||
override fun write(kryo: Kryo, output: Output, obj: Iterator<*>) {
|
||||
serializer.write(kryo, output, obj)
|
||||
}
|
||||
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<Iterator<*>>): Iterator<*> {
|
||||
val iterator = serializer.read(kryo, input, type)
|
||||
return fixIterator(iterator)
|
||||
}
|
||||
|
||||
private fun fixIterator(iterator: Iterator<*>) : Iterator<*> {
|
||||
|
||||
// Set expectedModCount of iterator
|
||||
val iterableInstance = iterableReferenceField?.get(iterator) ?: return iterator
|
||||
val modCountValue = modCountField?.getInt(iterableInstance) ?: return iterator
|
||||
expectedModCountField?.setInt(iterator, modCountValue)
|
||||
|
||||
return iterator
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds the named field in clazz or any of its superclasses.
|
||||
*/
|
||||
private fun findField(clazz: Class<*>, fieldName: String): Field? {
|
||||
return clazz.declaredFields.firstOrNull { x -> x.name == fieldName } ?: when {
|
||||
clazz.superclass != null -> {
|
||||
// Look in superclasses
|
||||
findField(clazz.superclass, fieldName)
|
||||
}
|
||||
else -> null // Not found
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -18,6 +18,7 @@ import net.corda.core.serialization.SerializeAsTokenContext
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.transactions.*
|
||||
import net.corda.core.utilities.OpaqueBytes
|
||||
import net.corda.core.utilities.SgxSupport
|
||||
import net.corda.serialization.internal.serializationContextKey
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
@ -67,13 +68,17 @@ object SerializedBytesSerializer : Serializer<SerializedBytes<Any>>() {
|
||||
* set via the constructor and the class is immutable.
|
||||
*/
|
||||
class ImmutableClassSerializer<T : Any>(val klass: KClass<T>) : Serializer<T>() {
|
||||
val props = klass.memberProperties.sortedBy { it.name }
|
||||
val propsByName = props.associateBy { it.name }
|
||||
val constructor = klass.primaryConstructor!!
|
||||
val props by lazy { klass.memberProperties.sortedBy { it.name } }
|
||||
val propsByName by lazy { props.associateBy { it.name } }
|
||||
val constructor by lazy { klass.primaryConstructor!! }
|
||||
|
||||
init {
|
||||
props.forEach {
|
||||
require(it !is KMutableProperty<*>) { "$it mutable property of class: ${klass} is unsupported" }
|
||||
// Verify that this class is immutable (all properties are final).
|
||||
// We disable this check inside SGX as the reflection blows up.
|
||||
if (!SgxSupport.isInsideEnclave) {
|
||||
props.forEach {
|
||||
require(it !is KMutableProperty<*>) { "$it mutable property of class: ${klass} is unsupported" }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,11 +10,20 @@ import com.esotericsoftware.kryo.io.Output
|
||||
import com.esotericsoftware.kryo.pool.KryoPool
|
||||
import com.esotericsoftware.kryo.serializers.ClosureSerializer
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.core.serialization.ClassWhitelist
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.internal.CheckpointSerializationContext
|
||||
import net.corda.core.serialization.internal.CheckpointSerializer
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.serialization.internal.*
|
||||
import net.corda.serialization.internal.AlwaysAcceptEncodingWhitelist
|
||||
import net.corda.serialization.internal.ByteBufferInputStream
|
||||
import net.corda.serialization.internal.CheckpointSerializationContextImpl
|
||||
import net.corda.serialization.internal.CordaSerializationEncoding
|
||||
import net.corda.serialization.internal.CordaSerializationMagic
|
||||
import net.corda.serialization.internal.QuasarWhitelist
|
||||
import net.corda.serialization.internal.SectionId
|
||||
import net.corda.serialization.internal.encodingNotPermittedFormat
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
val kryoMagic = CordaSerializationMagic("corda".toByteArray() + byteArrayOf(0, 0))
|
||||
|
@ -0,0 +1,2 @@
|
||||
net.corda.nodeapi.internal.persistence.factory.H2SessionFactoryFactory
|
||||
net.corda.nodeapi.internal.persistence.factory.PostgresSessionFactoryFactory
|
@ -9,7 +9,6 @@ import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.internal.div
|
||||
import net.corda.core.utilities.NetworkHostAndPort
|
||||
import org.assertj.core.api.Assertions.*
|
||||
import org.hibernate.exception.DataException
|
||||
import org.junit.Test
|
||||
import java.net.URL
|
||||
import java.nio.file.Path
|
||||
|
@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.cryptoservice.bouncycastle

import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureScheme
import net.corda.core.crypto.internal.cordaBouncyCastleProvider
import net.corda.core.internal.div
import net.corda.core.utilities.days
import net.corda.nodeapi.internal.config.CertificateStoreSupplier
@ -13,6 +14,7 @@ import net.corda.nodeapi.internal.cryptoservice.WrappedPrivateKey
import net.corda.nodeapi.internal.cryptoservice.WrappingMode
import net.corda.testing.core.ALICE_NAME
import net.corda.coretesting.internal.stubs.CertificateStoreStubs
import net.corda.nodeapi.internal.crypto.loadOrCreateKeyStore
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.bouncycastle.jce.provider.BouncyCastleProvider
@ -23,8 +25,10 @@ import org.junit.rules.TemporaryFolder
import java.io.FileOutputStream
import java.nio.file.Path
import java.security.*
import java.security.spec.ECGenParameterSpec
import java.time.Duration
import java.util.*
import javax.crypto.Cipher
import javax.security.auth.x500.X500Principal
import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
@ -59,7 +63,8 @@ class BCCryptoServiceTests {
    fun `BCCryptoService generate key pair and sign both data and cert`() {
        val cryptoService = BCCryptoService(ALICE_NAME.x500Principal, signingCertificateStore, wrappingKeyStorePath)
        // Testing every supported scheme.
        Crypto.supportedSignatureSchemes().filter { it != Crypto.COMPOSITE_KEY }.forEach { generateKeyAndSignForScheme(cryptoService, it) }
        Crypto.supportedSignatureSchemes().filter { it != Crypto.COMPOSITE_KEY
                && it.signatureName != "SHA512WITHSPHINCS256"}.forEach { generateKeyAndSignForScheme(cryptoService, it) }
    }

    private fun generateKeyAndSignForScheme(cryptoService: BCCryptoService, signatureScheme: SignatureScheme) {
@ -252,4 +257,27 @@ class BCCryptoServiceTests {

        Crypto.doVerify(publicKey, signature, data)
    }

    @Test(timeout=300_000)
    fun `cryptoService can sign with previously encoded version of wrapped key`() {
        val cryptoService = BCCryptoService(ALICE_NAME.x500Principal, signingCertificateStore, wrappingKeyStorePath)

        val wrappingKeyAlias = UUID.randomUUID().toString()
        cryptoService.createWrappingKey(wrappingKeyAlias)

        val wrappingKeyStore = loadOrCreateKeyStore(wrappingKeyStorePath, cryptoService.certificateStore.password, "PKCS12")
        val wrappingKey = wrappingKeyStore.getKey(wrappingKeyAlias, cryptoService.certificateStore.entryPassword.toCharArray())
        val cipher = Cipher.getInstance("AES", cordaBouncyCastleProvider)
        cipher.init(Cipher.WRAP_MODE, wrappingKey)

        val keyPairGenerator = KeyPairGenerator.getInstance("EC", cordaBouncyCastleProvider)
        keyPairGenerator.initialize(ECGenParameterSpec("secp256r1"))
        val keyPair = keyPairGenerator.generateKeyPair()
        val privateKeyMaterialWrapped = cipher.wrap(keyPair.private)
        val wrappedPrivateKey = WrappedPrivateKey(privateKeyMaterialWrapped, Crypto.ECDSA_SECP256R1_SHA256, encodingVersion = null)

        val data = "data".toByteArray()
        val signature = cryptoService.sign(wrappingKeyAlias, wrappedPrivateKey, data)
        Crypto.doVerify(keyPair.public, signature, data)
    }
}
@ -0,0 +1,26 @@
package net.corda.nodeapi.internal.persistence

import com.nhaarman.mockito_kotlin.mock
import net.corda.core.internal.NamedCacheFactory
import org.junit.Assert
import org.junit.Test

class HibernateConfigurationFactoryLoadingTest {
    @Test(timeout=300_000)
    fun checkErrorMessageForMissingFactory() {
        val jdbcUrl = "jdbc:madeUpNonense:foobar.com:1234"
        val presentFactories = listOf("H2", "PostgreSQL")
        try {
            val cacheFactory = mock<NamedCacheFactory>()
            HibernateConfiguration(
                    emptySet(),
                    DatabaseConfig(),
                    emptyList(),
                    jdbcUrl,
                    cacheFactory)
            Assert.fail("Expected exception not thrown")
        } catch (e: HibernateConfigException) {
            Assert.assertEquals("Failed to find a SessionFactoryFactory to handle $jdbcUrl - factories present for ${presentFactories}", e.message)
        }
    }
}
@ -0,0 +1,104 @@
package net.corda.nodeapi.internal.persistence

import com.nhaarman.mockito_kotlin.mock
import org.junit.Test
import java.sql.Connection
import java.sql.Savepoint

class RestrictedConnectionTest {

    private val connection : Connection = mock()
    private val savePoint : Savepoint = mock()
    private val restrictedConnection : RestrictedConnection = RestrictedConnection(connection)

    companion object {
        private const val TEST_STRING : String = "test"
        private const val TEST_INT : Int = 1
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testAbort() {
        restrictedConnection.abort { println("I'm just an executor for this test...") }
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testClearWarnings() {
        restrictedConnection.clearWarnings()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testClose() {
        restrictedConnection.close()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testCommit() {
        restrictedConnection.commit()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetSavepoint() {
        restrictedConnection.setSavepoint()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetSavepointWithName() {
        restrictedConnection.setSavepoint(TEST_STRING)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testReleaseSavepoint() {
        restrictedConnection.releaseSavepoint(savePoint)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testRollback() {
        restrictedConnection.rollback()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testRollbackWithSavepoint() {
        restrictedConnection.rollback(savePoint)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetCatalog() {
        restrictedConnection.catalog = TEST_STRING
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetTransactionIsolation() {
        restrictedConnection.transactionIsolation = TEST_INT
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetTypeMap() {
        val map: MutableMap<String, Class<*>> = mutableMapOf()
        restrictedConnection.typeMap = map
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetHoldability() {
        restrictedConnection.holdability = TEST_INT
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetSchema() {
        restrictedConnection.schema = TEST_STRING
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetNetworkTimeout() {
        restrictedConnection.setNetworkTimeout({ println("I'm just an executor for this test...") }, TEST_INT)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetAutoCommit() {
        restrictedConnection.autoCommit = true
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetReadOnly() {
        restrictedConnection.isReadOnly = true
    }
}
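The tests above pin down behaviour only: each blocked operation on the wrapped java.sql.Connection must throw UnsupportedOperationException. For orientation, here is a minimal sketch of a wrapper consistent with a representative subset of those expectations; it is an illustrative assumption, not the actual RestrictedConnection implementation from this change set, and the class name and exception message are hypothetical.

import java.sql.Connection
import java.sql.Savepoint
import java.util.concurrent.Executor

// Illustrative sketch only: delegate every Connection method to the underlying connection,
// but block a representative subset of the lifecycle and transaction-control operations the
// tests above expect to be forbidden.
class RestrictedConnectionSketch(private val delegate: Connection) : Connection by delegate {
    override fun abort(executor: Executor?) = restricted()
    override fun clearWarnings() = restricted()
    override fun close() = restricted()
    override fun commit() = restricted()
    override fun rollback() = restricted()
    override fun rollback(savepoint: Savepoint?) = restricted()
    override fun setSavepoint(): Savepoint = restricted()
    override fun setSavepoint(name: String?): Savepoint = restricted()
    override fun releaseSavepoint(savepoint: Savepoint?) = restricted()
    override fun setAutoCommit(autoCommit: Boolean) = restricted()
    override fun setReadOnly(readOnly: Boolean) = restricted()

    private fun restricted(): Nothing =
            throw UnsupportedOperationException("Not allowed on a restricted connection (sketch)")
}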
@ -0,0 +1,58 @@
package net.corda.nodeapi.internal.persistence

import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever
import org.junit.Test
import javax.persistence.EntityManager
import javax.persistence.EntityTransaction
import javax.persistence.LockModeType
import kotlin.test.assertTrue

class RestrictedEntityManagerTest {
    private val entitymanager = mock<EntityManager>()
    private val transaction = mock<EntityTransaction>()
    private val restrictedEntityManager = RestrictedEntityManager(entitymanager)

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testClose() {
        restrictedEntityManager.close()
    }

    @Test(timeout = 300_000)
    fun testClear() {
        restrictedEntityManager.clear()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testGetMetaModel() {
        restrictedEntityManager.getMetamodel()
    }

    @Test(timeout = 300_000)
    fun testGetTransaction() {
        whenever(entitymanager.transaction).doReturn(transaction)
        assertTrue(restrictedEntityManager.transaction is RestrictedEntityTransaction)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testJoinTransaction() {
        restrictedEntityManager.joinTransaction()
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testLockWithTwoParameters() {
        restrictedEntityManager.lock(Object(), LockModeType.OPTIMISTIC)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testLockWithThreeParameters() {
        val map: MutableMap<String, Any> = mutableMapOf()
        restrictedEntityManager.lock(Object(), LockModeType.OPTIMISTIC, map)
    }

    @Test(expected = UnsupportedOperationException::class, timeout=300_000)
    fun testSetProperty() {
        restrictedEntityManager.setProperty("number", 12)
    }
}
@ -4,7 +4,10 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.coretesting.internal.configureTestSSL
import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import org.junit.Test
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SNIHostName
@ -23,7 +26,7 @@ class SSLHelperTest {
        val keyStore = sslConfig.keyStore
        keyManagerFactory.init(CertificateStore.fromFile(keyStore.path, keyStore.storePassword, keyStore.entryPassword, false))
        val trustStore = sslConfig.trustStore
        trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), false))
        trustManagerFactory.init(initialiseTrustStoreAndEnableCrlChecking(CertificateStore.fromFile(trustStore.path, trustStore.storePassword, trustStore.entryPassword, false), RevocationConfigImpl(RevocationConfig.Mode.HARD_FAIL)))

        val sslHandler = createClientSslHelper(NetworkHostAndPort("localhost", 1234), setOf(legalName), keyManagerFactory, trustManagerFactory)
        val legalNameHash = SecureHash.sha256(legalName.toString()).toString().take(32).toLowerCase()
@ -34,4 +37,14 @@ class SSLHelperTest {
        assertEquals(1, sslHandler.engine().sslParameters.serverNames.size)
        assertEquals("$legalNameHash.corda.net", (sslHandler.engine().sslParameters.serverNames.first() as SNIHostName).asciiName)
    }

    @Test(timeout=300_000)
    fun `test distributionPointsToString`() {
        val certStore = CertificateStore.fromResource(
                "net/corda/nodeapi/internal/protonwrapper/netty/sslkeystore_Revoked.jks",
                DEV_CA_KEY_STORE_PASS, DEV_CA_PRIVATE_KEY_PASS)
        val distPoints = certStore.query { getCertificateChain(CORDA_CLIENT_TLS).map { it.distributionPointsToString() } }
        assertEquals(listOf("NO CRLDP ext", "http://day-v3-doorman.cordaconnect.io/doorman",
                "http://day3-doorman.cordaconnect.io/doorman", "http://day3-doorman.cordaconnect.io/subordinate", "NO CRLDP ext"), distPoints)
    }
}
@ -0,0 +1,56 @@
package net.corda.nodeapi.internal.protonwrapper.netty.revocation

import net.corda.core.utilities.Try
import net.corda.nodeapi.internal.DEV_CA_KEY_STORE_PASS
import net.corda.nodeapi.internal.DEV_CA_PRIVATE_KEY_PASS
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.netty.ExternalCrlSource
import org.bouncycastle.jcajce.provider.asymmetric.x509.CertificateFactory
import org.junit.Test
import java.math.BigInteger

import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.sql.Date
import kotlin.test.assertEquals
import kotlin.test.assertTrue

class ExternalSourceRevocationCheckerTest {

    @Test(timeout=300_000)
    fun checkRevoked() {
        val checkResult = performCheckOnDate(Date.valueOf("2019-09-27"))
        val failedChecks = checkResult.filterNot { it.second.isSuccess }
        assertEquals(1, failedChecks.size)
        assertEquals(BigInteger.valueOf(8310484079152632582), failedChecks.first().first.serialNumber)
    }

    @Test(timeout=300_000)
    fun checkTooEarly() {
        val checkResult = performCheckOnDate(Date.valueOf("2019-08-27"))
        assertTrue(checkResult.all { it.second.isSuccess })
    }

    private fun performCheckOnDate(date: Date): List<Pair<X509Certificate, Try<Unit>>> {
        val certStore = CertificateStore.fromResource(
                "net/corda/nodeapi/internal/protonwrapper/netty/sslkeystore_Revoked.jks",
                DEV_CA_KEY_STORE_PASS, DEV_CA_PRIVATE_KEY_PASS)

        val resourceAsStream = javaClass.getResourceAsStream("/net/corda/nodeapi/internal/protonwrapper/netty/doorman.crl")
        val crl = CertificateFactory().engineGenerateCRL(resourceAsStream) as X509CRL

        //val crlHolder = X509CRLHolder(resourceAsStream)
        //crlHolder.revokedCertificates as X509CRLEntryHolder

        val instance = ExternalSourceRevocationChecker(object : ExternalCrlSource {
            override fun fetch(certificate: X509Certificate): Set<X509CRL> = setOf(crl)
        }) { date }

        return certStore.query {
            getCertificateChain(X509Utilities.CORDA_CLIENT_TLS).map {
                Pair(it, Try.on { instance.check(it, mutableListOf()) })
            }
        }
    }
}
@ -0,0 +1,122 @@
package net.corda.nodeapi.internal.serialization.kryo

import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.serialization.EncodingWhitelist
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.coretesting.internal.rigorousMock
import net.corda.serialization.internal.AllWhitelist
import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.serialization.internal.CordaSerializationEncoding
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import java.util.*
import kotlin.collections.ArrayList
import kotlin.collections.HashMap
import kotlin.collections.HashSet
import kotlin.collections.LinkedHashMap
import kotlin.collections.LinkedHashSet

@RunWith(Parameterized::class)
class ArrayListItrConcurrentModificationException(private val compression: CordaSerializationEncoding?) {
    companion object {
        @Parameters(name = "{0}")
        @JvmStatic
        fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
    }

    @get:Rule
    val serializationRule = CheckpointSerializationEnvironmentRule(inheritable = true)
    private lateinit var context: CheckpointSerializationContext

    @Before
    fun setup() {
        context = CheckpointSerializationContextImpl(
                deserializationClassLoader = javaClass.classLoader,
                whitelist = AllWhitelist,
                properties = emptyMap(),
                objectReferencesEnabled = true,
                encoding = compression,
                encodingWhitelist = rigorousMock<EncodingWhitelist>().also {
                    if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
                })
    }

    @Test(timeout=300_000)
    fun `ArrayList iterator can checkpoint without error`() {
        runTestWithCollection(ArrayList())
    }

    @Test(timeout=300_000)
    fun `HashSet iterator can checkpoint without error`() {
        runTestWithCollection(HashSet())
    }

    @Test(timeout=300_000)
    fun `LinkedHashSet iterator can checkpoint without error`() {
        runTestWithCollection(LinkedHashSet())
    }

    @Test(timeout=300_000)
    fun `HashMap iterator can checkpoint without error`() {
        runTestWithCollection(HashMap())
    }

    @Test(timeout=300_000)
    fun `LinkedHashMap iterator can checkpoint without error`() {
        runTestWithCollection(LinkedHashMap())
    }

    @Test(timeout=300_000)
    fun `LinkedList iterator can checkpoint without error`() {
        runTestWithCollection(LinkedList())
    }

    private data class TestCheckpoint<C,I>(val list: C, val iterator: I)

    private fun runTestWithCollection(collection: MutableCollection<Int>) {

        for (i in 1..100) {
            collection.add(i)
        }

        val iterator = collection.iterator()
        iterator.next()

        val checkpoint = TestCheckpoint(collection, iterator)

        val serializedBytes = checkpoint.checkpointSerialize(context)
        val deserializedCheckpoint = serializedBytes.checkpointDeserialize(context)

        assertThat(deserializedCheckpoint.list).isEqualTo(collection)
        assertThat(deserializedCheckpoint.iterator.next()).isEqualTo(2)
        assertThat(deserializedCheckpoint.iterator.hasNext()).isTrue()
    }

    private fun runTestWithCollection(collection: MutableMap<Int, Int>) {

        for (i in 1..100) {
            collection[i] = i
        }

        val iterator = collection.iterator()
        iterator.next()

        val checkpoint = TestCheckpoint(collection, iterator)

        val serializedBytes = checkpoint.checkpointSerialize(context)
        val deserializedCheckpoint = serializedBytes.checkpointDeserialize(context)

        assertThat(deserializedCheckpoint.list).isEqualTo(collection)
        assertThat(deserializedCheckpoint.iterator.next().key).isEqualTo(2)
        assertThat(deserializedCheckpoint.iterator.hasNext()).isTrue()
    }
}
@ -0,0 +1,171 @@
package net.corda.nodeapi.internal.serialization.kryo

import org.junit.Ignore
import org.junit.Test
import org.junit.jupiter.api.assertDoesNotThrow
import java.util.LinkedList
import kotlin.test.assertEquals

class KryoCheckpointTest {

    private val testSize = 1000L

    /**
     * This test just ensures that the checkpoints still work in light of [LinkedHashMapEntrySerializer].
     */
    @Test(timeout=300_000)
    fun `linked hash map can checkpoint without error`() {
        var lastKey = ""
        val dummyMap = linkedMapOf<String, Long>()
        for (i in 0..testSize) {
            dummyMap[i.toString()] = i
        }
        var it = dummyMap.iterator()
        while (it.hasNext()) {
            lastKey = it.next().key
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals(testSize.toString(), lastKey)
    }

    @Test(timeout=300_000)
    fun `empty linked hash map can checkpoint without error`() {
        val dummyMap = linkedMapOf<String, Long>()
        val it = dummyMap.iterator()
        val itKeys = dummyMap.keys.iterator()
        val itValues = dummyMap.values.iterator()
        val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
        val bytesKeys = KryoCheckpointSerializer.serialize(itKeys, KRYO_CHECKPOINT_CONTEXT)
        val bytesValues = KryoCheckpointSerializer.serialize(itValues, KRYO_CHECKPOINT_CONTEXT)
        assertDoesNotThrow {
            KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
            KryoCheckpointSerializer.deserialize(bytesKeys, itKeys.javaClass, KRYO_CHECKPOINT_CONTEXT)
            KryoCheckpointSerializer.deserialize(bytesValues, itValues.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
    }

    @Test(timeout=300_000)
    fun `linked hash map with null values can checkpoint without error`() {
        val dummyMap = linkedMapOf<String?, Long?>().apply {
            put("foo", 2L)
            put(null, null)
            put("bar", 3L)
        }
        val it = dummyMap.iterator()
        val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)

        val itKeys = dummyMap.keys.iterator()
        itKeys.next()
        itKeys.next()
        val bytesKeys = KryoCheckpointSerializer.serialize(itKeys, KRYO_CHECKPOINT_CONTEXT)

        val itValues = dummyMap.values.iterator()
        val bytesValues = KryoCheckpointSerializer.serialize(itValues, KRYO_CHECKPOINT_CONTEXT)

        assertDoesNotThrow {
            KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
            val desItKeys = KryoCheckpointSerializer.deserialize(bytesKeys, itKeys.javaClass, KRYO_CHECKPOINT_CONTEXT)
            assertEquals("bar", desItKeys.next())
            KryoCheckpointSerializer.deserialize(bytesValues, itValues.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
    }

    @Test(timeout=300_000)
    fun `linked hash map keys can checkpoint without error`() {
        var lastKey = ""
        val dummyMap = linkedMapOf<String, Long>()
        for (i in 0..testSize) {
            dummyMap[i.toString()] = i
        }
        var it = dummyMap.keys.iterator()
        while (it.hasNext()) {
            lastKey = it.next()
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals(testSize.toString(), lastKey)
    }

    @Test(timeout=300_000)
    fun `linked hash map values can checkpoint without error`() {
        var lastValue = 0L
        val dummyMap = linkedMapOf<String, Long>()
        for (i in 0..testSize) {
            dummyMap[i.toString()] = i
        }
        var it = dummyMap.values.iterator()
        while (it.hasNext()) {
            lastValue = it.next()
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals(testSize, lastValue)
    }

    @Test(timeout = 300_000)
    fun `linked hash map values can checkpoint without error, even with repeats`() {
        var lastValue = "0"
        val dummyMap = linkedMapOf<String, String>()
        for (i in 0..testSize) {
            dummyMap[i.toString()] = (i % 10).toString()
        }
        var it = dummyMap.values.iterator()
        while (it.hasNext()) {
            lastValue = it.next()
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals((testSize % 10).toString(), lastValue)
    }

    @Ignore("Kryo optimizes boxed primitives so this does not work. Need to customise ReferenceResolver to stop it doing it.")
    @Test(timeout = 300_000)
    fun `linked hash map values can checkpoint without error, even with repeats for boxed primitives`() {
        var lastValue = 0L
        val dummyMap = linkedMapOf<String, Long>()
        for (i in 0..testSize) {
            dummyMap[i.toString()] = (i % 10)
        }
        var it = dummyMap.values.iterator()
        while (it.hasNext()) {
            lastValue = it.next()
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals(testSize % 10, lastValue)
    }

    /**
     * This test just ensures that the checkpoints still work in light of [LinkedHashMapEntrySerializer].
     */
    @Test(timeout=300_000)
    fun `linked hash set can checkpoint without error`() {
        var result: Any = 0L
        val dummySet = linkedSetOf<Any>().apply { addAll(0..testSize) }
        var it = dummySet.iterator()
        while (it.hasNext()) {
            result = it.next()
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals(testSize, result)
    }

    /**
     * This test just ensures that the checkpoints still work in light of [LinkedListItrSerializer].
     */
    @Test(timeout=300_000)
    fun `linked list can checkpoint without error`() {
        var result: Any = 0L
        val dummyList = LinkedList<Long>().apply { addAll(0..testSize) }

        var it = dummyList.iterator()
        while (it.hasNext()) {
            result = it.next()
            val bytes = KryoCheckpointSerializer.serialize(it, KRYO_CHECKPOINT_CONTEXT)
            it = KryoCheckpointSerializer.deserialize(bytes, it.javaClass, KRYO_CHECKPOINT_CONTEXT)
        }
        assertEquals(testSize, result)
    }
}
@ -0,0 +1,3 @@
Represents test data containing real certificates produced by the DayWatch Doorman, as well as a CRL file.

For all of the keystores, the password is "cordacadevpass".
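As a usage hint, any keystore from this test-data set can be opened with the password stated above. The sketch below uses only the JDK KeyStore API; the resource path mirrors the one used by the tests earlier in this diff, and both that path and the "JKS" store type are illustrative assumptions rather than part of the change set.

import java.security.KeyStore

// Minimal sketch: open the revoked-certificate test keystore with the documented password.
// The resource path and the "JKS" type are assumptions for illustration only.
fun loadRevokedTestKeyStore(): KeyStore {
    val keyStore = KeyStore.getInstance("JKS")
    val stream = requireNotNull(
            KeyStore::class.java.getResourceAsStream(
                    "/net/corda/nodeapi/internal/protonwrapper/netty/sslkeystore_Revoked.jks")
    ) { "Test keystore resource not found" }
    stream.use { keyStore.load(it, "cordacadevpass".toCharArray()) }
    return keyStore
}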
Binary file not shown.
Binary file not shown.
Binary file not shown.