mirror of
https://github.com/corda/corda.git
synced 2025-04-20 17:11:26 +00:00
commit
26af54412b
@ -73,7 +73,7 @@ are compatible with TLS 1.2, while the default scheme per key type is also shown
|
||||
| | and SHA-256 | | vendors. | | - tls |
|
||||
| | | | - network map (CN) |
|
||||
+-------------------------+---------------------------------------------------------------+-----+-------------------------+
|
||||
| | ECDSA using the | | secp256k1 is the curve adopted by Bitcoin and as such there | YES | |
|
||||
| | ECDSA using the | | secp256k1 is the curve adopted by Bitcoin and as such there | NO | |
|
||||
| | Koblitz k1 curve | | is a wealth of infrastructure, code and advanced algorithms | | |
|
||||
| | (secp256k1) | | designed for use with it. This curve is standardised by | | |
|
||||
| | and SHA-256 | | NIST as part of the "Suite B" cryptographic algorithms and | | |
|
||||
|
@ -31,14 +31,14 @@ class NodeKeystoreCheckTest : IntegrationTest() {
|
||||
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) {
|
||||
assertThatThrownBy {
|
||||
startNode(customOverrides = mapOf("devMode" to false)).getOrThrow()
|
||||
}.hasMessageContaining("Identity certificate not found")
|
||||
}.hasMessageContaining("One or more keyStores (identity or TLS) or trustStore not found.")
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `node should throw exception if cert path doesn't chain to the trust root`() {
|
||||
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) {
|
||||
// Create keystores
|
||||
// Create keystores.
|
||||
val keystorePassword = "password"
|
||||
val certificatesDirectory = baseDirectory(ALICE_NAME) / "certificates"
|
||||
val signingCertStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory, keystorePassword)
|
||||
@ -57,7 +57,7 @@ class NodeKeystoreCheckTest : IntegrationTest() {
|
||||
|
||||
// Fiddle with node keystore.
|
||||
signingCertStore.get().update {
|
||||
// Self signed root
|
||||
// Self signed root.
|
||||
val badRootKeyPair = Crypto.generateKeyPair()
|
||||
val badRoot = X509Utilities.createSelfSignedCACertificate(X500Principal("O=Bad Root,L=Lodnon,C=GB"), badRootKeyPair)
|
||||
val nodeCA = getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA)
|
||||
|
@ -18,13 +18,13 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient
|
||||
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
|
||||
import net.corda.nodeapi.internal.ArtemisTcpTransport
|
||||
import net.corda.nodeapi.internal.config.MutualSslConfiguration
|
||||
import net.corda.nodeapi.internal.registerDevP2pCertificates
|
||||
import net.corda.nodeapi.internal.crypto.*
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities
|
||||
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
|
||||
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
|
||||
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
|
||||
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
|
||||
import net.corda.nodeapi.internal.protonwrapper.netty.init
|
||||
import net.corda.nodeapi.internal.registerDevP2pCertificates
|
||||
import net.corda.nodeapi.internal.registerDevSigningCertificates
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
@ -489,7 +489,10 @@ class ProtonWrapperTests {
|
||||
sharedThreadPool = sharedEventGroup)
|
||||
}
|
||||
|
||||
private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME, maxMessageSize: Int = MAX_MESSAGE_SIZE, crlCheckSoftFail: Boolean = true): AMQPServer {
|
||||
private fun createServer(port: Int,
|
||||
name: CordaX500Name = ALICE_NAME,
|
||||
maxMessageSize: Int = MAX_MESSAGE_SIZE,
|
||||
crlCheckSoftFail: Boolean = true): AMQPServer {
|
||||
val baseDirectory = temporaryFolder.root.toPath() / "server"
|
||||
val certificatesDirectory = baseDirectory / "certificates"
|
||||
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)
|
||||
|
@ -72,7 +72,11 @@ import net.corda.node.utilities.NodeBuildProperties
|
||||
import net.corda.nodeapi.internal.DevIdentityGenerator
|
||||
import net.corda.nodeapi.internal.NodeInfoAndSigned
|
||||
import net.corda.nodeapi.internal.SignedNodeInfo
|
||||
import net.corda.nodeapi.internal.config.CertificateStore
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_ROOT_CA
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities.DISTRIBUTED_NOTARY_ALIAS_PREFIX
|
||||
import net.corda.nodeapi.internal.crypto.X509Utilities.NODE_IDENTITY_ALIAS_PREFIX
|
||||
import net.corda.nodeapi.internal.persistence.*
|
||||
@ -248,20 +252,20 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
return proxies.fold(ops) { delegate, decorate -> decorate(delegate) }
|
||||
}
|
||||
|
||||
private fun initKeyStore(): X509Certificate {
|
||||
private fun initKeyStores(): X509Certificate {
|
||||
if (configuration.devMode) {
|
||||
configuration.configureWithDevSSLCertificate()
|
||||
}
|
||||
return validateKeyStore()
|
||||
return validateKeyStores()
|
||||
}
|
||||
|
||||
open fun generateAndSaveNodeInfo(): NodeInfo {
|
||||
check(started == null) { "Node has already been started" }
|
||||
log.info("Generating nodeInfo ...")
|
||||
val trustRoot = initKeyStore()
|
||||
val trustRoot = initKeyStores()
|
||||
val (identity, identityKeyPair) = obtainIdentity(notaryConfig = null)
|
||||
startDatabase()
|
||||
val nodeCa = configuration.signingCertificateStore.get()[X509Utilities.CORDA_CLIENT_CA]
|
||||
val nodeCa = configuration.signingCertificateStore.get()[CORDA_CLIENT_CA]
|
||||
identityService.start(trustRoot, listOf(identity.certificate, nodeCa))
|
||||
return database.use {
|
||||
it.transaction {
|
||||
@ -288,8 +292,8 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
}
|
||||
log.info("Node starting up ...")
|
||||
|
||||
val trustRoot = initKeyStore()
|
||||
val nodeCa = configuration.signingCertificateStore.get()[X509Utilities.CORDA_CLIENT_CA]
|
||||
val trustRoot = initKeyStores()
|
||||
val nodeCa = configuration.signingCertificateStore.get()[CORDA_CLIENT_CA]
|
||||
initialiseJVMAgents()
|
||||
|
||||
schemaService.mappedSchemasWarnings().forEach {
|
||||
@ -313,7 +317,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
servicesForResolution.start(netParams)
|
||||
networkMapCache.start(netParams.notaries)
|
||||
|
||||
startDatabase()
|
||||
startDatabase(metricRegistry)
|
||||
val (identity, identityKeyPair) = obtainIdentity(notaryConfig = null)
|
||||
identityService.start(trustRoot, listOf(identity.certificate, nodeCa))
|
||||
|
||||
@ -718,30 +722,50 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
@VisibleForTesting
|
||||
protected open fun acceptableLiveFiberCountOnStop(): Int = 0
|
||||
|
||||
private fun validateKeyStore(): X509Certificate {
|
||||
val containCorrectKeys = try {
|
||||
// This will throw IOException if key file not found or KeyStoreException if keystore password is incorrect.
|
||||
val sslKeystore = configuration.p2pSslOptions.keyStore.get()
|
||||
val identitiesKeystore = configuration.signingCertificateStore.get()
|
||||
X509Utilities.CORDA_CLIENT_TLS in sslKeystore && X509Utilities.CORDA_CLIENT_CA in identitiesKeystore
|
||||
private fun getCertificateStores(): AllCertificateStores? {
|
||||
return try {
|
||||
// The following will throw IOException if key file not found or KeyStoreException if keystore password is incorrect.
|
||||
val sslKeyStore = configuration.p2pSslOptions.keyStore.get()
|
||||
val identitiesKeyStore = configuration.signingCertificateStore.get()
|
||||
val trustStore = configuration.p2pSslOptions.trustStore.get()
|
||||
AllCertificateStores(trustStore, sslKeyStore, identitiesKeyStore)
|
||||
} catch (e: KeyStoreException) {
|
||||
log.warn("Certificate key store found but key store password does not match configuration.")
|
||||
false
|
||||
log.warn("At least one of the keystores or truststore passwords does not match configuration.")
|
||||
null
|
||||
} catch (e: IOException) {
|
||||
log.error("IO exception while trying to validate keystore", e)
|
||||
false
|
||||
log.error("IO exception while trying to validate keystores and truststore", e)
|
||||
null
|
||||
}
|
||||
require(containCorrectKeys) {
|
||||
"Identity certificate not found. " +
|
||||
"Please either copy your existing identity key and certificate from another node, " +
|
||||
}
|
||||
|
||||
private data class AllCertificateStores(val trustStore: CertificateStore, val sslKeyStore: CertificateStore, val identitiesKeyStore: CertificateStore)
|
||||
|
||||
private fun validateKeyStores(): X509Certificate {
|
||||
// Step 1. Check trustStore, sslKeyStore and identitiesKeyStore exist.
|
||||
val certStores = requireNotNull(getCertificateStores()) {
|
||||
"One or more keyStores (identity or TLS) or trustStore not found. " +
|
||||
"Please either copy your existing keys and certificates from another node, " +
|
||||
"or if you don't have one yet, fill out the config file and run corda.jar --initial-registration. " +
|
||||
"Read more at: https://docs.corda.net/permissioning.html"
|
||||
}
|
||||
// Step 2. Check that trustStore contains the correct key-alias entry.
|
||||
require(CORDA_ROOT_CA in certStores.trustStore) {
|
||||
"Alias for trustRoot key not found. Please ensure you have an updated trustStore file."
|
||||
}
|
||||
// Step 3. Check that tls keyStore contains the correct key-alias entry.
|
||||
require(CORDA_CLIENT_TLS in certStores.sslKeyStore) {
|
||||
"Alias for TLS key not found. Please ensure you have an updated TLS keyStore file."
|
||||
}
|
||||
|
||||
// Check all cert path chain to the trusted root
|
||||
val sslCertChainRoot = configuration.p2pSslOptions.keyStore.get().query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.last()
|
||||
val nodeCaCertChainRoot = configuration.signingCertificateStore.get().query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.last()
|
||||
val trustRoot = configuration.p2pSslOptions.trustStore.get()[X509Utilities.CORDA_ROOT_CA]
|
||||
// Step 4. Check that identity keyStores contain the correct key-alias entry for Node CA.
|
||||
require(CORDA_CLIENT_CA in certStores.identitiesKeyStore) {
|
||||
"Alias for Node CA key not found. Please ensure you have an updated identity keyStore file."
|
||||
}
|
||||
|
||||
// Step 5. Check all cert paths chain to the trusted root.
|
||||
val trustRoot = certStores.trustStore[CORDA_ROOT_CA]
|
||||
val sslCertChainRoot = certStores.sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.last()
|
||||
val nodeCaCertChainRoot = certStores.identitiesKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.last()
|
||||
|
||||
require(sslCertChainRoot == trustRoot) { "TLS certificate must chain to the trusted root." }
|
||||
require(nodeCaCertChainRoot == trustRoot) { "Client CA certificate must chain to the trusted root." }
|
||||
@ -759,7 +783,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
// Specific class so that MockNode can catch it.
|
||||
class DatabaseConfigurationException(msg: String) : CordaException(msg)
|
||||
|
||||
protected open fun startDatabase() {
|
||||
protected open fun startDatabase(metricRegistry: MetricRegistry? = null) {
|
||||
log.debug {
|
||||
val driverClasses = DriverManager.getDrivers().asSequence().map { it.javaClass.name }
|
||||
"Available JDBC drivers: $driverClasses"
|
||||
@ -768,7 +792,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
|
||||
if (props.isEmpty) throw DatabaseConfigurationException("There must be a database configured.")
|
||||
val isH2Database = isH2Database(props.getProperty("dataSource.url", ""))
|
||||
val schemas = if (isH2Database) schemaService.internalSchemas() else schemaService.schemaOptions.keys
|
||||
database.startHikariPool(props, configuration.database, schemas)
|
||||
database.startHikariPool(props, configuration.database, schemas, metricRegistry)
|
||||
// Now log the vendor string as this will also cause a connection to be tested eagerly.
|
||||
logVendorString(database, log)
|
||||
}
|
||||
@ -1060,9 +1084,9 @@ fun createCordaPersistence(databaseConfig: DatabaseConfig,
|
||||
return CordaPersistence(databaseConfig, schemaService.schemaOptions.keys, attributeConverters)
|
||||
}
|
||||
|
||||
fun CordaPersistence.startHikariPool(hikariProperties: Properties, databaseConfig: DatabaseConfig, schemas: Set<MappedSchema>) {
|
||||
fun CordaPersistence.startHikariPool(hikariProperties: Properties, databaseConfig: DatabaseConfig, schemas: Set<MappedSchema>, metricRegistry: MetricRegistry? = null) {
|
||||
try {
|
||||
val dataSource = DataSourceFactory.createDataSource(hikariProperties)
|
||||
val dataSource = DataSourceFactory.createDataSource(hikariProperties, metricRegistry = metricRegistry)
|
||||
val jdbcUrl = hikariProperties.getProperty("dataSource.url", "")
|
||||
val schemaMigration = SchemaMigration(schemas, dataSource, databaseConfig)
|
||||
schemaMigration.nodeStartup(dataSource.connection.use { DBCheckpointStorage().getCheckpointCount(it) != 0L }, isH2Database(jdbcUrl))
|
||||
@ -1085,4 +1109,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl
|
||||
}
|
||||
// Here we're using the node's RPC key store as the RPC client's trust store.
|
||||
return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword)
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package net.corda.node.internal
|
||||
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import com.zaxxer.hikari.HikariConfig
|
||||
import com.zaxxer.hikari.HikariDataSource
|
||||
import com.zaxxer.hikari.util.PropertyElf
|
||||
@ -35,10 +36,14 @@ object DataSourceFactory {
|
||||
}.set(null, SynchronizedGetPutRemove<String, Database>())
|
||||
}
|
||||
|
||||
fun createDataSource(hikariProperties: Properties, pool: Boolean = true): DataSource {
|
||||
fun createDataSource(hikariProperties: Properties, pool: Boolean = true, metricRegistry: MetricRegistry? = null): DataSource {
|
||||
val config = HikariConfig(hikariProperties)
|
||||
return if (pool) {
|
||||
HikariDataSource(config)
|
||||
val dataSource = HikariDataSource(config)
|
||||
if (metricRegistry != null) {
|
||||
dataSource.metricRegistry = metricRegistry
|
||||
}
|
||||
dataSource
|
||||
} else {
|
||||
// Basic init for the one test that wants to go via this API but without starting a HikariPool:
|
||||
(Class.forName(hikariProperties.getProperty("dataSourceClassName")).newInstance() as DataSource).also {
|
||||
|
@ -357,7 +357,7 @@ open class Node(configuration: NodeConfiguration,
|
||||
* This is not using the H2 "automatic mixed mode" directly but leans on many of the underpinnings. For more details
|
||||
* on H2 URLs and configuration see: http://www.h2database.com/html/features.html#database_url
|
||||
*/
|
||||
override fun startDatabase() {
|
||||
override fun startDatabase(metricRegistry: MetricRegistry?) {
|
||||
val databaseUrl = configuration.dataSourceProperties.getProperty("dataSource.url")
|
||||
val h2Prefix = "jdbc:h2:file:"
|
||||
|
||||
@ -396,7 +396,7 @@ open class Node(configuration: NodeConfiguration,
|
||||
printBasicNodeInfo("Database connection url is", databaseUrl)
|
||||
}
|
||||
|
||||
super.startDatabase()
|
||||
super.startDatabase(metricRegistry)
|
||||
database.closeOnStop()
|
||||
}
|
||||
|
||||
@ -514,4 +514,4 @@ open class Node(configuration: NodeConfiguration,
|
||||
|
||||
log.info("Shutdown complete")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -100,22 +100,6 @@ class TLSAuthenticationTests {
|
||||
testConnect(serverSocket, clientSocket, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `All EC K1`() {
|
||||
val (serverSocketFactory, clientSocketFactory) = buildTLSFactories(
|
||||
rootCAScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
intermediateCAScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
client1CAScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
client1TLSScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
client2CAScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
client2TLSScheme = Crypto.ECDSA_SECP256K1_SHA256
|
||||
)
|
||||
|
||||
val (serverSocket, clientSocket) = buildTLSSockets(serverSocketFactory, clientSocketFactory, 0, 0)
|
||||
|
||||
testConnect(serverSocket, clientSocket, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
|
||||
}
|
||||
|
||||
// Server's public key type is the one selected if users use different key types (e.g RSA and EC R1).
|
||||
@Test
|
||||
fun `Server RSA - Client EC R1 - CAs all EC R1`() {
|
||||
@ -162,22 +146,6 @@ class TLSAuthenticationTests {
|
||||
testConnect(serverSocket, clientSocket, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Server EC K1 - Client EC R1 - CAs all RSA`() {
|
||||
val (serverSocketFactory, clientSocketFactory) = buildTLSFactories(
|
||||
rootCAScheme = Crypto.RSA_SHA256,
|
||||
intermediateCAScheme = Crypto.RSA_SHA256,
|
||||
client1CAScheme = Crypto.RSA_SHA256,
|
||||
client1TLSScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
client2CAScheme = Crypto.RSA_SHA256,
|
||||
client2TLSScheme = Crypto.ECDSA_SECP256R1_SHA256
|
||||
)
|
||||
|
||||
val (serverSocket, clientSocket) = buildTLSSockets(serverSocketFactory, clientSocketFactory, 0, 0)
|
||||
testConnect(serverSocket, clientSocket, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
fun `Server EC R1 - Client RSA - Mixed CAs`() {
|
||||
val (serverSocketFactory, clientSocketFactory) = buildTLSFactories(
|
||||
@ -185,7 +153,7 @@ class TLSAuthenticationTests {
|
||||
intermediateCAScheme = Crypto.RSA_SHA256,
|
||||
client1CAScheme = Crypto.RSA_SHA256,
|
||||
client1TLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
|
||||
client2CAScheme = Crypto.ECDSA_SECP256K1_SHA256,
|
||||
client2CAScheme = Crypto.ECDSA_SECP256R1_SHA256,
|
||||
client2TLSScheme = Crypto.RSA_SHA256
|
||||
)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
package net.corda.testing.node.internal
|
||||
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import com.google.common.jimfs.Configuration.unix
|
||||
import com.google.common.jimfs.Jimfs
|
||||
import com.nhaarman.mockito_kotlin.doReturn
|
||||
@ -404,8 +405,8 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
|
||||
get() = _serializationWhitelists
|
||||
private var dbCloser: (() -> Any?)? = null
|
||||
|
||||
override fun startDatabase() {
|
||||
super.startDatabase()
|
||||
override fun startDatabase(metricRegistry: MetricRegistry?) {
|
||||
super.startDatabase(metricRegistry)
|
||||
dbCloser = database::close
|
||||
runOnStop += dbCloser!!
|
||||
}
|
||||
|
@ -3,10 +3,16 @@ package net.corda.tools.shell
|
||||
import com.google.common.io.Files
|
||||
import com.jcraft.jsch.ChannelExec
|
||||
import com.jcraft.jsch.JSch
|
||||
import com.nhaarman.mockito_kotlin.any
|
||||
import com.nhaarman.mockito_kotlin.doAnswer
|
||||
import com.nhaarman.mockito_kotlin.mock
|
||||
import net.corda.client.rpc.RPCException
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.StartableByRPC
|
||||
import net.corda.core.internal.div
|
||||
import net.corda.core.messaging.ClientRpcSslOptions
|
||||
import net.corda.core.messaging.CordaRPCOps
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.node.services.Permissions
|
||||
import net.corda.node.services.Permissions.Companion.all
|
||||
@ -27,10 +33,12 @@ import net.corda.testing.internal.IntegrationTestSchemas
|
||||
import net.corda.testing.internal.toDatabaseSchemaName
|
||||
import net.corda.testing.internal.useSslRpcOverrides
|
||||
import net.corda.testing.node.User
|
||||
import net.corda.tools.shell.utlities.ANSIProgressRenderer
|
||||
import org.apache.activemq.artemis.api.core.ActiveMQSecurityException
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.bouncycastle.util.io.Streams
|
||||
import org.crsh.text.RenderPrintWriter
|
||||
import org.junit.ClassRule
|
||||
import org.junit.Ignore
|
||||
import org.junit.Rule
|
||||
@ -250,4 +258,159 @@ class InteractiveShellIntegrationTest : IntegrationTest() {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED")
|
||||
@StartableByRPC
|
||||
class NoOpFlow : FlowLogic<Unit>() {
|
||||
override val progressTracker = ProgressTracker()
|
||||
override fun call() {
|
||||
println("NO OP!")
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED")
|
||||
@StartableByRPC
|
||||
class NoOpFlowA : FlowLogic<Unit>() {
|
||||
override val progressTracker = ProgressTracker()
|
||||
override fun call() {
|
||||
println("NO OP! (A)")
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNUSED")
|
||||
@StartableByRPC
|
||||
class BurbleFlow : FlowLogic<Unit>() {
|
||||
override val progressTracker = ProgressTracker()
|
||||
override fun call() {
|
||||
println("NO OP! (Burble)")
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `shell should start flow with fully qualified class name`() {
|
||||
val user = User("u", "p", setOf(all()))
|
||||
var successful = false
|
||||
driver(DriverParameters(notarySpecs = emptyList())) {
|
||||
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
|
||||
val node = nodeFuture.getOrThrow()
|
||||
|
||||
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
|
||||
user = user.username, password = user.password,
|
||||
hostAndPort = node.rpcAddress)
|
||||
InteractiveShell.startShell(conf)
|
||||
|
||||
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
|
||||
val output = mock<RenderPrintWriter> {
|
||||
on { println(any<String>()) } doAnswer {
|
||||
val line = it.arguments[0]
|
||||
println("$line")
|
||||
if ((line is String) && (line.startsWith("Flow completed with result:")))
|
||||
successful = true
|
||||
}
|
||||
}
|
||||
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
|
||||
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
|
||||
}
|
||||
InteractiveShell.runFlowByNameFragment(
|
||||
InteractiveShellIntegrationTest::class.qualifiedName + "\$NoOpFlow",
|
||||
"", output, node.rpc, ansiProgressRenderer)
|
||||
}
|
||||
assertThat(successful).isTrue()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `shell should start flow with unique un-qualified class name`() {
|
||||
val user = User("u", "p", setOf(all()))
|
||||
var successful = false
|
||||
driver(DriverParameters(notarySpecs = emptyList())) {
|
||||
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
|
||||
val node = nodeFuture.getOrThrow()
|
||||
|
||||
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
|
||||
user = user.username, password = user.password,
|
||||
hostAndPort = node.rpcAddress)
|
||||
InteractiveShell.startShell(conf)
|
||||
|
||||
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
|
||||
val output = mock<RenderPrintWriter> {
|
||||
on { println(any<String>()) } doAnswer {
|
||||
val line = it.arguments[0]
|
||||
println("$line")
|
||||
if ((line is String) && (line.startsWith("Flow completed with result:")))
|
||||
successful = true
|
||||
}
|
||||
}
|
||||
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
|
||||
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
|
||||
}
|
||||
InteractiveShell.runFlowByNameFragment(
|
||||
"InteractiveShellIntegrationTest\$NoOpFlowA",
|
||||
"", output, node.rpc, ansiProgressRenderer)
|
||||
}
|
||||
assertThat(successful).isTrue()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `shell should fail to start flow with ambiguous class name`() {
|
||||
val user = User("u", "p", setOf(all()))
|
||||
var successful = false
|
||||
driver(DriverParameters(notarySpecs = emptyList())) {
|
||||
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
|
||||
val node = nodeFuture.getOrThrow()
|
||||
|
||||
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
|
||||
user = user.username, password = user.password,
|
||||
hostAndPort = node.rpcAddress)
|
||||
InteractiveShell.startShell(conf)
|
||||
|
||||
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
|
||||
val output = mock<RenderPrintWriter> {
|
||||
on { println(any<String>()) } doAnswer {
|
||||
val line = it.arguments[0]
|
||||
println("$line")
|
||||
if ((line is String) && (line.startsWith("Ambiguous name provided, please be more specific.")))
|
||||
successful = true
|
||||
}
|
||||
}
|
||||
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
|
||||
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
|
||||
}
|
||||
InteractiveShell.runFlowByNameFragment(
|
||||
InteractiveShellIntegrationTest::class.qualifiedName + "\$NoOpFlo",
|
||||
"", output, node.rpc, ansiProgressRenderer)
|
||||
}
|
||||
assertThat(successful).isTrue()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `shell should start flow with partially matching class name`() {
|
||||
val user = User("u", "p", setOf(all()))
|
||||
var successful = false
|
||||
driver(DriverParameters(notarySpecs = emptyList())) {
|
||||
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
|
||||
val node = nodeFuture.getOrThrow()
|
||||
|
||||
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
|
||||
user = user.username, password = user.password,
|
||||
hostAndPort = node.rpcAddress)
|
||||
InteractiveShell.startShell(conf)
|
||||
|
||||
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
|
||||
val output = mock<RenderPrintWriter> {
|
||||
on { println(any<String>()) } doAnswer {
|
||||
val line = it.arguments[0]
|
||||
println("$line")
|
||||
if ((line is String) && (line.startsWith("Flow completed with result")))
|
||||
successful = true
|
||||
}
|
||||
}
|
||||
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
|
||||
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
|
||||
}
|
||||
InteractiveShell.runFlowByNameFragment(
|
||||
"Burble",
|
||||
"", output, node.rpc, ansiProgressRenderer)
|
||||
}
|
||||
assertThat(successful).isTrue()
|
||||
}
|
||||
}
|
||||
|
@ -209,6 +209,10 @@ object InteractiveShell {
|
||||
// TODO: This should become the default renderer rather than something used specifically by commands.
|
||||
private val outputMapper by lazy { createOutputMapper() }
|
||||
|
||||
@VisibleForTesting
|
||||
lateinit var latch: CountDownLatch
|
||||
private set
|
||||
|
||||
/**
|
||||
* Called from the 'flow' shell command. Takes a name fragment and finds a matching flow, or prints out
|
||||
* the list of options if the request is ambiguous. Then parses [inputData] as constructor arguments using
|
||||
@ -220,7 +224,7 @@ object InteractiveShell {
|
||||
output: RenderPrintWriter,
|
||||
rpcOps: CordaRPCOps,
|
||||
ansiProgressRenderer: ANSIProgressRenderer,
|
||||
om: ObjectMapper) {
|
||||
om: ObjectMapper = outputMapper) {
|
||||
val matches = try {
|
||||
rpcOps.registeredFlows().filter { nameFragment in it }
|
||||
} catch (e: PermissionException) {
|
||||
@ -230,23 +234,24 @@ object InteractiveShell {
|
||||
if (matches.isEmpty()) {
|
||||
output.println("No matching flow found, run 'flow list' to see your options.", Color.red)
|
||||
return
|
||||
} else if (matches.size > 1) {
|
||||
} else if (matches.size > 1 && matches.find { it.endsWith(nameFragment)} == null) {
|
||||
output.println("Ambiguous name provided, please be more specific. Your options are:")
|
||||
matches.forEachIndexed { i, s -> output.println("${i + 1}. $s", Color.yellow) }
|
||||
return
|
||||
}
|
||||
|
||||
val flowName = matches.find { it.endsWith(nameFragment)} ?: matches.single()
|
||||
val flowClazz: Class<FlowLogic<*>> = if (classLoader != null) {
|
||||
uncheckedCast(Class.forName(matches.single(), true, classLoader))
|
||||
uncheckedCast(Class.forName(flowName, true, classLoader))
|
||||
} else {
|
||||
uncheckedCast(Class.forName(matches.single()))
|
||||
uncheckedCast(Class.forName(flowName))
|
||||
}
|
||||
try {
|
||||
// Show the progress tracker on the console until the flow completes or is interrupted with a
|
||||
// Ctrl-C keypress.
|
||||
val stateObservable = runFlowFromString({ clazz, args -> rpcOps.startTrackedFlowDynamic(clazz, *args) }, inputData, flowClazz, om)
|
||||
|
||||
val latch = CountDownLatch(1)
|
||||
latch = CountDownLatch(1)
|
||||
ansiProgressRenderer.render(stateObservable, latch::countDown)
|
||||
// Wait for the flow to end and the progress tracker to notice. By the time the latch is released
|
||||
// the tracker is done with the screen.
|
||||
|
Loading…
x
Reference in New Issue
Block a user