Merge pull request #1386 from corda/os-merge-7459115

O/S merge 7459115
This commit is contained in:
Shams Asari 2018-09-11 15:00:34 +01:00 committed by GitHub
commit 26af54412b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 250 additions and 81 deletions

View File

@ -73,7 +73,7 @@ are compatible with TLS 1.2, while the default scheme per key type is also shown
| | and SHA-256 | | vendors. | | - tls |
| | | | - network map (CN) |
+-------------------------+---------------------------------------------------------------+-----+-------------------------+
| | ECDSA using the | | secp256k1 is the curve adopted by Bitcoin and as such there | YES | |
| | ECDSA using the | | secp256k1 is the curve adopted by Bitcoin and as such there | NO | |
| | Koblitz k1 curve | | is a wealth of infrastructure, code and advanced algorithms | | |
| | (secp256k1) | | designed for use with it. This curve is standardised by | | |
| | and SHA-256 | | NIST as part of the "Suite B" cryptographic algorithms and | | |

View File

@ -31,14 +31,14 @@ class NodeKeystoreCheckTest : IntegrationTest() {
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) {
assertThatThrownBy {
startNode(customOverrides = mapOf("devMode" to false)).getOrThrow()
}.hasMessageContaining("Identity certificate not found")
}.hasMessageContaining("One or more keyStores (identity or TLS) or trustStore not found.")
}
}
@Test
fun `node should throw exception if cert path doesn't chain to the trust root`() {
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) {
// Create keystores
// Create keystores.
val keystorePassword = "password"
val certificatesDirectory = baseDirectory(ALICE_NAME) / "certificates"
val signingCertStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory, keystorePassword)
@ -57,7 +57,7 @@ class NodeKeystoreCheckTest : IntegrationTest() {
// Fiddle with node keystore.
signingCertStore.get().update {
// Self signed root
// Self signed root.
val badRootKeyPair = Crypto.generateKeyPair()
val badRoot = X509Utilities.createSelfSignedCACertificate(X500Principal("O=Bad Root,L=Lodnon,C=GB"), badRootKeyPair)
val nodeCA = getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA)

View File

@ -18,13 +18,13 @@ import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisTcpTransport
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.crypto.*
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPConfiguration
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.registerDevSigningCertificates
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
@ -489,7 +489,10 @@ class ProtonWrapperTests {
sharedThreadPool = sharedEventGroup)
}
private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME, maxMessageSize: Int = MAX_MESSAGE_SIZE, crlCheckSoftFail: Boolean = true): AMQPServer {
private fun createServer(port: Int,
name: CordaX500Name = ALICE_NAME,
maxMessageSize: Int = MAX_MESSAGE_SIZE,
crlCheckSoftFail: Boolean = true): AMQPServer {
val baseDirectory = temporaryFolder.root.toPath() / "server"
val certificatesDirectory = baseDirectory / "certificates"
val signingCertificateStore = CertificateStoreStubs.Signing.withCertificatesDirectory(certificatesDirectory)

View File

@ -72,7 +72,11 @@ import net.corda.node.utilities.NodeBuildProperties
import net.corda.nodeapi.internal.DevIdentityGenerator
import net.corda.nodeapi.internal.NodeInfoAndSigned
import net.corda.nodeapi.internal.SignedNodeInfo
import net.corda.nodeapi.internal.config.CertificateStore
import net.corda.nodeapi.internal.crypto.X509Utilities
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_CLIENT_TLS
import net.corda.nodeapi.internal.crypto.X509Utilities.CORDA_ROOT_CA
import net.corda.nodeapi.internal.crypto.X509Utilities.DISTRIBUTED_NOTARY_ALIAS_PREFIX
import net.corda.nodeapi.internal.crypto.X509Utilities.NODE_IDENTITY_ALIAS_PREFIX
import net.corda.nodeapi.internal.persistence.*
@ -248,20 +252,20 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
return proxies.fold(ops) { delegate, decorate -> decorate(delegate) }
}
private fun initKeyStore(): X509Certificate {
private fun initKeyStores(): X509Certificate {
if (configuration.devMode) {
configuration.configureWithDevSSLCertificate()
}
return validateKeyStore()
return validateKeyStores()
}
open fun generateAndSaveNodeInfo(): NodeInfo {
check(started == null) { "Node has already been started" }
log.info("Generating nodeInfo ...")
val trustRoot = initKeyStore()
val trustRoot = initKeyStores()
val (identity, identityKeyPair) = obtainIdentity(notaryConfig = null)
startDatabase()
val nodeCa = configuration.signingCertificateStore.get()[X509Utilities.CORDA_CLIENT_CA]
val nodeCa = configuration.signingCertificateStore.get()[CORDA_CLIENT_CA]
identityService.start(trustRoot, listOf(identity.certificate, nodeCa))
return database.use {
it.transaction {
@ -288,8 +292,8 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
log.info("Node starting up ...")
val trustRoot = initKeyStore()
val nodeCa = configuration.signingCertificateStore.get()[X509Utilities.CORDA_CLIENT_CA]
val trustRoot = initKeyStores()
val nodeCa = configuration.signingCertificateStore.get()[CORDA_CLIENT_CA]
initialiseJVMAgents()
schemaService.mappedSchemasWarnings().forEach {
@ -313,7 +317,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
servicesForResolution.start(netParams)
networkMapCache.start(netParams.notaries)
startDatabase()
startDatabase(metricRegistry)
val (identity, identityKeyPair) = obtainIdentity(notaryConfig = null)
identityService.start(trustRoot, listOf(identity.certificate, nodeCa))
@ -718,30 +722,50 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
@VisibleForTesting
protected open fun acceptableLiveFiberCountOnStop(): Int = 0
private fun validateKeyStore(): X509Certificate {
val containCorrectKeys = try {
// This will throw IOException if key file not found or KeyStoreException if keystore password is incorrect.
val sslKeystore = configuration.p2pSslOptions.keyStore.get()
val identitiesKeystore = configuration.signingCertificateStore.get()
X509Utilities.CORDA_CLIENT_TLS in sslKeystore && X509Utilities.CORDA_CLIENT_CA in identitiesKeystore
private fun getCertificateStores(): AllCertificateStores? {
return try {
// The following will throw IOException if key file not found or KeyStoreException if keystore password is incorrect.
val sslKeyStore = configuration.p2pSslOptions.keyStore.get()
val identitiesKeyStore = configuration.signingCertificateStore.get()
val trustStore = configuration.p2pSslOptions.trustStore.get()
AllCertificateStores(trustStore, sslKeyStore, identitiesKeyStore)
} catch (e: KeyStoreException) {
log.warn("Certificate key store found but key store password does not match configuration.")
false
log.warn("At least one of the keystores or truststore passwords does not match configuration.")
null
} catch (e: IOException) {
log.error("IO exception while trying to validate keystore", e)
false
log.error("IO exception while trying to validate keystores and truststore", e)
null
}
require(containCorrectKeys) {
"Identity certificate not found. " +
"Please either copy your existing identity key and certificate from another node, " +
}
private data class AllCertificateStores(val trustStore: CertificateStore, val sslKeyStore: CertificateStore, val identitiesKeyStore: CertificateStore)
private fun validateKeyStores(): X509Certificate {
// Step 1. Check trustStore, sslKeyStore and identitiesKeyStore exist.
val certStores = requireNotNull(getCertificateStores()) {
"One or more keyStores (identity or TLS) or trustStore not found. " +
"Please either copy your existing keys and certificates from another node, " +
"or if you don't have one yet, fill out the config file and run corda.jar --initial-registration. " +
"Read more at: https://docs.corda.net/permissioning.html"
}
// Step 2. Check that trustStore contains the correct key-alias entry.
require(CORDA_ROOT_CA in certStores.trustStore) {
"Alias for trustRoot key not found. Please ensure you have an updated trustStore file."
}
// Step 3. Check that tls keyStore contains the correct key-alias entry.
require(CORDA_CLIENT_TLS in certStores.sslKeyStore) {
"Alias for TLS key not found. Please ensure you have an updated TLS keyStore file."
}
// Check all cert path chain to the trusted root
val sslCertChainRoot = configuration.p2pSslOptions.keyStore.get().query { getCertificateChain(X509Utilities.CORDA_CLIENT_TLS) }.last()
val nodeCaCertChainRoot = configuration.signingCertificateStore.get().query { getCertificateChain(X509Utilities.CORDA_CLIENT_CA) }.last()
val trustRoot = configuration.p2pSslOptions.trustStore.get()[X509Utilities.CORDA_ROOT_CA]
// Step 4. Check that identity keyStores contain the correct key-alias entry for Node CA.
require(CORDA_CLIENT_CA in certStores.identitiesKeyStore) {
"Alias for Node CA key not found. Please ensure you have an updated identity keyStore file."
}
// Step 5. Check all cert paths chain to the trusted root.
val trustRoot = certStores.trustStore[CORDA_ROOT_CA]
val sslCertChainRoot = certStores.sslKeyStore.query { getCertificateChain(CORDA_CLIENT_TLS) }.last()
val nodeCaCertChainRoot = certStores.identitiesKeyStore.query { getCertificateChain(CORDA_CLIENT_CA) }.last()
require(sslCertChainRoot == trustRoot) { "TLS certificate must chain to the trusted root." }
require(nodeCaCertChainRoot == trustRoot) { "Client CA certificate must chain to the trusted root." }
@ -759,7 +783,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
// Specific class so that MockNode can catch it.
class DatabaseConfigurationException(msg: String) : CordaException(msg)
protected open fun startDatabase() {
protected open fun startDatabase(metricRegistry: MetricRegistry? = null) {
log.debug {
val driverClasses = DriverManager.getDrivers().asSequence().map { it.javaClass.name }
"Available JDBC drivers: $driverClasses"
@ -768,7 +792,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
if (props.isEmpty) throw DatabaseConfigurationException("There must be a database configured.")
val isH2Database = isH2Database(props.getProperty("dataSource.url", ""))
val schemas = if (isH2Database) schemaService.internalSchemas() else schemaService.schemaOptions.keys
database.startHikariPool(props, configuration.database, schemas)
database.startHikariPool(props, configuration.database, schemas, metricRegistry)
// Now log the vendor string as this will also cause a connection to be tested eagerly.
logVendorString(database, log)
}
@ -1060,9 +1084,9 @@ fun createCordaPersistence(databaseConfig: DatabaseConfig,
return CordaPersistence(databaseConfig, schemaService.schemaOptions.keys, attributeConverters)
}
fun CordaPersistence.startHikariPool(hikariProperties: Properties, databaseConfig: DatabaseConfig, schemas: Set<MappedSchema>) {
fun CordaPersistence.startHikariPool(hikariProperties: Properties, databaseConfig: DatabaseConfig, schemas: Set<MappedSchema>, metricRegistry: MetricRegistry? = null) {
try {
val dataSource = DataSourceFactory.createDataSource(hikariProperties)
val dataSource = DataSourceFactory.createDataSource(hikariProperties, metricRegistry = metricRegistry)
val jdbcUrl = hikariProperties.getProperty("dataSource.url", "")
val schemaMigration = SchemaMigration(schemas, dataSource, databaseConfig)
schemaMigration.nodeStartup(dataSource.connection.use { DBCheckpointStorage().getCheckpointCount(it) != 0L }, isH2Database(jdbcUrl))
@ -1085,4 +1109,4 @@ fun clientSslOptionsCompatibleWith(nodeRpcOptions: NodeRpcOptions): ClientRpcSsl
}
// Here we're using the node's RPC key store as the RPC client's trust store.
return ClientRpcSslOptions(trustStorePath = nodeRpcOptions.sslConfig!!.keyStorePath, trustStorePassword = nodeRpcOptions.sslConfig!!.keyStorePassword)
}
}

View File

@ -1,5 +1,6 @@
package net.corda.node.internal
import com.codahale.metrics.MetricRegistry
import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource
import com.zaxxer.hikari.util.PropertyElf
@ -35,10 +36,14 @@ object DataSourceFactory {
}.set(null, SynchronizedGetPutRemove<String, Database>())
}
fun createDataSource(hikariProperties: Properties, pool: Boolean = true): DataSource {
fun createDataSource(hikariProperties: Properties, pool: Boolean = true, metricRegistry: MetricRegistry? = null): DataSource {
val config = HikariConfig(hikariProperties)
return if (pool) {
HikariDataSource(config)
val dataSource = HikariDataSource(config)
if (metricRegistry != null) {
dataSource.metricRegistry = metricRegistry
}
dataSource
} else {
// Basic init for the one test that wants to go via this API but without starting a HikariPool:
(Class.forName(hikariProperties.getProperty("dataSourceClassName")).newInstance() as DataSource).also {

View File

@ -357,7 +357,7 @@ open class Node(configuration: NodeConfiguration,
* This is not using the H2 "automatic mixed mode" directly but leans on many of the underpinnings. For more details
* on H2 URLs and configuration see: http://www.h2database.com/html/features.html#database_url
*/
override fun startDatabase() {
override fun startDatabase(metricRegistry: MetricRegistry?) {
val databaseUrl = configuration.dataSourceProperties.getProperty("dataSource.url")
val h2Prefix = "jdbc:h2:file:"
@ -396,7 +396,7 @@ open class Node(configuration: NodeConfiguration,
printBasicNodeInfo("Database connection url is", databaseUrl)
}
super.startDatabase()
super.startDatabase(metricRegistry)
database.closeOnStop()
}
@ -514,4 +514,4 @@ open class Node(configuration: NodeConfiguration,
log.info("Shutdown complete")
}
}
}

View File

@ -100,22 +100,6 @@ class TLSAuthenticationTests {
testConnect(serverSocket, clientSocket, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
}
@Test
fun `All EC K1`() {
val (serverSocketFactory, clientSocketFactory) = buildTLSFactories(
rootCAScheme = Crypto.ECDSA_SECP256K1_SHA256,
intermediateCAScheme = Crypto.ECDSA_SECP256K1_SHA256,
client1CAScheme = Crypto.ECDSA_SECP256K1_SHA256,
client1TLSScheme = Crypto.ECDSA_SECP256K1_SHA256,
client2CAScheme = Crypto.ECDSA_SECP256K1_SHA256,
client2TLSScheme = Crypto.ECDSA_SECP256K1_SHA256
)
val (serverSocket, clientSocket) = buildTLSSockets(serverSocketFactory, clientSocketFactory, 0, 0)
testConnect(serverSocket, clientSocket, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
}
// Server's public key type is the one selected if users use different key types (e.g RSA and EC R1).
@Test
fun `Server RSA - Client EC R1 - CAs all EC R1`() {
@ -162,22 +146,6 @@ class TLSAuthenticationTests {
testConnect(serverSocket, clientSocket, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
}
@Test
fun `Server EC K1 - Client EC R1 - CAs all RSA`() {
val (serverSocketFactory, clientSocketFactory) = buildTLSFactories(
rootCAScheme = Crypto.RSA_SHA256,
intermediateCAScheme = Crypto.RSA_SHA256,
client1CAScheme = Crypto.RSA_SHA256,
client1TLSScheme = Crypto.ECDSA_SECP256K1_SHA256,
client2CAScheme = Crypto.RSA_SHA256,
client2TLSScheme = Crypto.ECDSA_SECP256R1_SHA256
)
val (serverSocket, clientSocket) = buildTLSSockets(serverSocketFactory, clientSocketFactory, 0, 0)
testConnect(serverSocket, clientSocket, "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256")
}
@Test
fun `Server EC R1 - Client RSA - Mixed CAs`() {
val (serverSocketFactory, clientSocketFactory) = buildTLSFactories(
@ -185,7 +153,7 @@ class TLSAuthenticationTests {
intermediateCAScheme = Crypto.RSA_SHA256,
client1CAScheme = Crypto.RSA_SHA256,
client1TLSScheme = Crypto.ECDSA_SECP256R1_SHA256,
client2CAScheme = Crypto.ECDSA_SECP256K1_SHA256,
client2CAScheme = Crypto.ECDSA_SECP256R1_SHA256,
client2TLSScheme = Crypto.RSA_SHA256
)

View File

@ -1,5 +1,6 @@
package net.corda.testing.node.internal
import com.codahale.metrics.MetricRegistry
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.nhaarman.mockito_kotlin.doReturn
@ -404,8 +405,8 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe
get() = _serializationWhitelists
private var dbCloser: (() -> Any?)? = null
override fun startDatabase() {
super.startDatabase()
override fun startDatabase(metricRegistry: MetricRegistry?) {
super.startDatabase(metricRegistry)
dbCloser = database::close
runOnStop += dbCloser!!
}

View File

@ -3,10 +3,16 @@ package net.corda.tools.shell
import com.google.common.io.Files
import com.jcraft.jsch.ChannelExec
import com.jcraft.jsch.JSch
import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.mock
import net.corda.client.rpc.RPCException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByRPC
import net.corda.core.internal.div
import net.corda.core.messaging.ClientRpcSslOptions
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.Permissions
import net.corda.node.services.Permissions.Companion.all
@ -27,10 +33,12 @@ import net.corda.testing.internal.IntegrationTestSchemas
import net.corda.testing.internal.toDatabaseSchemaName
import net.corda.testing.internal.useSslRpcOverrides
import net.corda.testing.node.User
import net.corda.tools.shell.utlities.ANSIProgressRenderer
import org.apache.activemq.artemis.api.core.ActiveMQSecurityException
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.bouncycastle.util.io.Streams
import org.crsh.text.RenderPrintWriter
import org.junit.ClassRule
import org.junit.Ignore
import org.junit.Rule
@ -250,4 +258,159 @@ class InteractiveShellIntegrationTest : IntegrationTest() {
}
}
@Suppress("UNUSED")
@StartableByRPC
class NoOpFlow : FlowLogic<Unit>() {
override val progressTracker = ProgressTracker()
override fun call() {
println("NO OP!")
}
}
@Suppress("UNUSED")
@StartableByRPC
class NoOpFlowA : FlowLogic<Unit>() {
override val progressTracker = ProgressTracker()
override fun call() {
println("NO OP! (A)")
}
}
@Suppress("UNUSED")
@StartableByRPC
class BurbleFlow : FlowLogic<Unit>() {
override val progressTracker = ProgressTracker()
override fun call() {
println("NO OP! (Burble)")
}
}
@Test
fun `shell should start flow with fully qualified class name`() {
val user = User("u", "p", setOf(all()))
var successful = false
driver(DriverParameters(notarySpecs = emptyList())) {
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
val node = nodeFuture.getOrThrow()
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
user = user.username, password = user.password,
hostAndPort = node.rpcAddress)
InteractiveShell.startShell(conf)
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
val output = mock<RenderPrintWriter> {
on { println(any<String>()) } doAnswer {
val line = it.arguments[0]
println("$line")
if ((line is String) && (line.startsWith("Flow completed with result:")))
successful = true
}
}
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
}
InteractiveShell.runFlowByNameFragment(
InteractiveShellIntegrationTest::class.qualifiedName + "\$NoOpFlow",
"", output, node.rpc, ansiProgressRenderer)
}
assertThat(successful).isTrue()
}
@Test
fun `shell should start flow with unique un-qualified class name`() {
val user = User("u", "p", setOf(all()))
var successful = false
driver(DriverParameters(notarySpecs = emptyList())) {
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
val node = nodeFuture.getOrThrow()
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
user = user.username, password = user.password,
hostAndPort = node.rpcAddress)
InteractiveShell.startShell(conf)
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
val output = mock<RenderPrintWriter> {
on { println(any<String>()) } doAnswer {
val line = it.arguments[0]
println("$line")
if ((line is String) && (line.startsWith("Flow completed with result:")))
successful = true
}
}
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
}
InteractiveShell.runFlowByNameFragment(
"InteractiveShellIntegrationTest\$NoOpFlowA",
"", output, node.rpc, ansiProgressRenderer)
}
assertThat(successful).isTrue()
}
@Test
fun `shell should fail to start flow with ambiguous class name`() {
val user = User("u", "p", setOf(all()))
var successful = false
driver(DriverParameters(notarySpecs = emptyList())) {
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
val node = nodeFuture.getOrThrow()
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
user = user.username, password = user.password,
hostAndPort = node.rpcAddress)
InteractiveShell.startShell(conf)
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
val output = mock<RenderPrintWriter> {
on { println(any<String>()) } doAnswer {
val line = it.arguments[0]
println("$line")
if ((line is String) && (line.startsWith("Ambiguous name provided, please be more specific.")))
successful = true
}
}
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
}
InteractiveShell.runFlowByNameFragment(
InteractiveShellIntegrationTest::class.qualifiedName + "\$NoOpFlo",
"", output, node.rpc, ansiProgressRenderer)
}
assertThat(successful).isTrue()
}
@Test
fun `shell should start flow with partially matching class name`() {
val user = User("u", "p", setOf(all()))
var successful = false
driver(DriverParameters(notarySpecs = emptyList())) {
val nodeFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
val node = nodeFuture.getOrThrow()
val conf = ShellConfiguration(commandsDirectory = Files.createTempDir().toPath(),
user = user.username, password = user.password,
hostAndPort = node.rpcAddress)
InteractiveShell.startShell(conf)
// setup and configure some mocks required by InteractiveShell.runFlowByNameFragment()
val output = mock<RenderPrintWriter> {
on { println(any<String>()) } doAnswer {
val line = it.arguments[0]
println("$line")
if ((line is String) && (line.startsWith("Flow completed with result")))
successful = true
}
}
val ansiProgressRenderer = mock<ANSIProgressRenderer> {
on { render(any(), any()) } doAnswer { InteractiveShell.latch.countDown() }
}
InteractiveShell.runFlowByNameFragment(
"Burble",
"", output, node.rpc, ansiProgressRenderer)
}
assertThat(successful).isTrue()
}
}

View File

@ -209,6 +209,10 @@ object InteractiveShell {
// TODO: This should become the default renderer rather than something used specifically by commands.
private val outputMapper by lazy { createOutputMapper() }
@VisibleForTesting
lateinit var latch: CountDownLatch
private set
/**
* Called from the 'flow' shell command. Takes a name fragment and finds a matching flow, or prints out
* the list of options if the request is ambiguous. Then parses [inputData] as constructor arguments using
@ -220,7 +224,7 @@ object InteractiveShell {
output: RenderPrintWriter,
rpcOps: CordaRPCOps,
ansiProgressRenderer: ANSIProgressRenderer,
om: ObjectMapper) {
om: ObjectMapper = outputMapper) {
val matches = try {
rpcOps.registeredFlows().filter { nameFragment in it }
} catch (e: PermissionException) {
@ -230,23 +234,24 @@ object InteractiveShell {
if (matches.isEmpty()) {
output.println("No matching flow found, run 'flow list' to see your options.", Color.red)
return
} else if (matches.size > 1) {
} else if (matches.size > 1 && matches.find { it.endsWith(nameFragment)} == null) {
output.println("Ambiguous name provided, please be more specific. Your options are:")
matches.forEachIndexed { i, s -> output.println("${i + 1}. $s", Color.yellow) }
return
}
val flowName = matches.find { it.endsWith(nameFragment)} ?: matches.single()
val flowClazz: Class<FlowLogic<*>> = if (classLoader != null) {
uncheckedCast(Class.forName(matches.single(), true, classLoader))
uncheckedCast(Class.forName(flowName, true, classLoader))
} else {
uncheckedCast(Class.forName(matches.single()))
uncheckedCast(Class.forName(flowName))
}
try {
// Show the progress tracker on the console until the flow completes or is interrupted with a
// Ctrl-C keypress.
val stateObservable = runFlowFromString({ clazz, args -> rpcOps.startTrackedFlowDynamic(clazz, *args) }, inputData, flowClazz, om)
val latch = CountDownLatch(1)
latch = CountDownLatch(1)
ansiProgressRenderer.render(stateObservable, latch::countDown)
// Wait for the flow to end and the progress tracker to notice. By the time the latch is released
// the tracker is done with the screen.