[CORDA-792] Standalone Shell (#2663)

- Existing embedded Shell connects via RPC including checking RPC user credentials (before was a direct use of CordaRPCOps):  in dev mode when console terminal is enabled, node created `shell` user.
- New Standalone Shell app with the same functionalities as Shell: connects to a node via RPC Client,  can use SSL and run SSH server.
This commit is contained in:
szymonsztuka
2018-03-07 09:57:32 +00:00
committed by GitHub
parent 8fe94bca2d
commit 72074c76c7
45 changed files with 1367 additions and 281 deletions

View File

@ -1,173 +0,0 @@
package net.corda.node
import co.paralleluniverse.fibers.Suspendable
import com.jcraft.jsch.ChannelExec
import com.jcraft.jsch.JSch
import com.jcraft.jsch.JSchException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.identity.Party
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.driver.DriverParameters
import net.corda.testing.node.User
import net.corda.testing.driver.driver
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.util.io.Streams
import org.junit.Test
import java.net.ConnectException
import java.util.regex.Pattern
import kotlin.test.assertTrue
import kotlin.test.fail
class SSHServerTest {
@Test()
fun `ssh server does not start be default`() {
val user = User("u", "p", setOf())
// The driver will automatically pick up the annotated flows below
driver {
val node = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user))
node.getOrThrow()
val session = JSch().getSession("u", "localhost", 2222)
session.setConfig("StrictHostKeyChecking", "no")
session.setPassword("p")
try {
session.connect()
fail()
} catch (e:JSchException) {
assertTrue(e.cause is ConnectException)
}
}
}
@Test
fun `ssh server starts when configured`() {
val user = User("u", "p", setOf())
// The driver will automatically pick up the annotated flows below
driver {
val node = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user),
customOverrides = mapOf("sshd" to mapOf("port" to 2222)))
node.getOrThrow()
val session = JSch().getSession("u", "localhost", 2222)
session.setConfig("StrictHostKeyChecking", "no")
session.setPassword("p")
session.connect()
assertTrue(session.isConnected)
}
}
@Test
fun `ssh server verify credentials`() {
val user = User("u", "p", setOf())
// The driver will automatically pick up the annotated flows below
driver {
val node = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user),
customOverrides = mapOf("sshd" to mapOf("port" to 2222)))
node.getOrThrow()
val session = JSch().getSession("u", "localhost", 2222)
session.setConfig("StrictHostKeyChecking", "no")
session.setPassword("p_is_bad_password")
try {
session.connect()
fail("Server should reject invalid credentials")
} catch (e: JSchException) {
//There is no specialized exception for this
assertTrue(e.message == "Auth fail")
}
}
}
@Test
fun `ssh respects permissions`() {
val user = User("u", "p", setOf(startFlow<FlowICanRun>()))
// The driver will automatically pick up the annotated flows below
driver(DriverParameters(isDebug = true)) {
val node = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user),
customOverrides = mapOf("sshd" to mapOf("port" to 2222)))
node.getOrThrow()
val session = JSch().getSession("u", "localhost", 2222)
session.setConfig("StrictHostKeyChecking", "no")
session.setPassword("p")
session.connect()
assertTrue(session.isConnected)
val channel = session.openChannel("exec") as ChannelExec
channel.setCommand("start FlowICannotRun otherParty: \"${ALICE_NAME}\"")
channel.connect()
val response = String(Streams.readAll(channel.inputStream))
val flowNameEscaped = Pattern.quote("StartFlow.${SSHServerTest::class.qualifiedName}$${FlowICannotRun::class.simpleName}")
channel.disconnect()
session.disconnect()
assertThat(response).matches("(?s)User not authorized to perform RPC call .*")
}
}
@Test
fun `ssh runs flows`() {
val user = User("u", "p", setOf(startFlow<FlowICanRun>()))
// The driver will automatically pick up the annotated flows below
driver(DriverParameters(isDebug = true)) {
val node = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user),
customOverrides = mapOf("sshd" to mapOf("port" to 2222)))
node.getOrThrow()
val session = JSch().getSession("u", "localhost", 2222)
session.setConfig("StrictHostKeyChecking", "no")
session.setPassword("p")
session.connect()
assertTrue(session.isConnected)
val channel = session.openChannel("exec") as ChannelExec
channel.setCommand("start FlowICanRun")
channel.connect()
val response = String(Streams.readAll(channel.inputStream))
val linesWithDoneCount = response.lines().filter { line -> line.contains("Done") }
// There are ANSI control characters involved, so we want to avoid direct byte to byte matching.
assertThat(linesWithDoneCount).hasSize(1)
}
}
@StartableByRPC
@InitiatingFlow
class FlowICanRun : FlowLogic<String>() {
private val HELLO_STEP = ProgressTracker.Step("Hello")
@Suspendable
override fun call(): String {
progressTracker?.currentStep = HELLO_STEP
return "bambam"
}
override val progressTracker: ProgressTracker? = ProgressTracker(HELLO_STEP)
}
@StartableByRPC
@InitiatingFlow
class FlowICannotRun(val otherParty: Party) : FlowLogic<String>() {
@Suspendable
override fun call(): String = initiateFlow(otherParty).receive<String>().unwrap { it }
override val progressTracker: ProgressTracker? = ProgressTracker()
}
}

View File

@ -5,8 +5,8 @@ import net.corda.client.rpc.internal.createCordaRPCClientWithSsl
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.Permissions.Companion.all
import net.corda.node.testsupport.withCertificates
import net.corda.node.testsupport.withKeyStores
import net.corda.testing.common.internal.withCertificates
import net.corda.testing.common.internal.withKeyStores
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.PortAllocation
import net.corda.testing.driver.driver

View File

@ -1,61 +0,0 @@
package net.corda.node.shell;
// See the comments at the top of run.java
import net.corda.core.messaging.CordaRPCOps;
import net.corda.node.utilities.ANSIProgressRenderer;
import net.corda.node.utilities.CRaSHANSIProgressRenderer;
import org.crsh.cli.*;
import org.crsh.command.*;
import org.crsh.text.*;
import org.crsh.text.ui.TableElement;
import java.util.*;
import static net.corda.node.shell.InteractiveShell.*;
@Man(
"Allows you to start flows, list the ones available and to watch flows currently running on the node.\n\n" +
"Starting flow is the primary way in which you command the node to change the ledger.\n\n" +
"This command is generic, so the right way to use it depends on the flow you wish to start. You can use the 'flow start'\n" +
"command with either a full class name, or a substring of the class name that's unambiguous. The parameters to the \n" +
"flow constructors (the right one is picked automatically) are then specified using the same syntax as for the run command."
)
public class FlowShellCommand extends InteractiveShellCommand {
@Command
@Usage("Start a (work)flow on the node. This is how you can change the ledger.")
public void start(
@Usage("The class name of the flow to run, or an unambiguous substring") @Argument String name,
@Usage("The data to pass as input") @Argument(unquote = false) List<String> input
) {
startFlow(name, input, out, ops(), ansiProgressRenderer());
}
// TODO Limit number of flows shown option?
@Command
@Usage("watch information about state machines running on the node with result information")
public void watch(InvocationContext<TableElement> context) throws Exception {
runStateMachinesView(out, ops());
}
static void startFlow(@Usage("The class name of the flow to run, or an unambiguous substring") @Argument String name,
@Usage("The data to pass as input") @Argument(unquote = false) List<String> input,
RenderPrintWriter out,
CordaRPCOps rpcOps,
ANSIProgressRenderer ansiProgressRenderer) {
if (name == null) {
out.println("You must pass a name for the flow, see 'man flow'", Color.red);
return;
}
String inp = input == null ? "" : String.join(" ", input).trim();
runFlowByNameFragment(name, inp, out, rpcOps, ansiProgressRenderer != null ? ansiProgressRenderer : new CRaSHANSIProgressRenderer(out) );
}
@Command
@Usage("list flows that user can start")
public void list(InvocationContext<String> context) throws Exception {
for (String name : ops().registeredFlows()) {
context.provide(name + System.lineSeparator());
}
}
}

View File

@ -1,56 +0,0 @@
package net.corda.node.shell;
import net.corda.core.messaging.*;
import net.corda.client.jackson.*;
import org.crsh.cli.*;
import org.crsh.command.*;
import java.util.*;
// Note that this class cannot be converted to Kotlin because CRaSH does not understand InvocationContext<Map<?, ?>> which
// is the closest you can get in Kotlin to raw types.
public class RunShellCommand extends InteractiveShellCommand {
@Command
@Man(
"Runs a method from the CordaRPCOps interface, which is the same interface exposed to RPC clients.\n\n" +
"You can learn more about what commands are available by typing 'run' just by itself, or by\n" +
"consulting the developer guide at https://docs.corda.net/api/kotlin/corda/net.corda.core.messaging/-corda-r-p-c-ops/index.html"
)
@Usage("runs a method from the CordaRPCOps interface on the node.")
public Object main(
InvocationContext<Map> context,
@Usage("The command to run") @Argument(unquote = false) List<String> command
) {
StringToMethodCallParser<CordaRPCOps> parser = new StringToMethodCallParser<>(CordaRPCOps.class, objectMapper());
if (command == null) {
emitHelp(context, parser);
return null;
}
return InteractiveShell.runRPCFromString(command, out, context, ops());
}
private void emitHelp(InvocationContext<Map> context, StringToMethodCallParser<CordaRPCOps> parser) {
// Sends data down the pipeline about what commands are available. CRaSH will render it nicely.
// Each element we emit is a map of column -> content.
Map<String, String> cmdsAndArgs = parser.getAvailableCommands();
for (Map.Entry<String, String> entry : cmdsAndArgs.entrySet()) {
// Skip these entries as they aren't really interesting for the user.
if (entry.getKey().equals("startFlowDynamic")) continue;
if (entry.getKey().equals("getProtocolVersion")) continue;
// Use a LinkedHashMap to ensure that the Command column comes first.
Map<String, String> m = new LinkedHashMap<>();
m.put("Command", entry.getKey());
m.put("Parameter types", entry.getValue());
try {
context.provide(m);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}

View File

@ -1,19 +0,0 @@
package net.corda.node.shell;
// A simple forwarder to the "flow start" command, for easier typing.
import net.corda.node.utilities.ANSIProgressRenderer;
import net.corda.node.utilities.CRaSHANSIProgressRenderer;
import org.crsh.cli.*;
import java.util.*;
public class StartShellCommand extends InteractiveShellCommand {
@Command
@Man("An alias for 'flow start'. Example: \"start Yo target: Some other company\"")
public void main(@Usage("The class name of the flow to run, or an unambiguous substring") @Argument String name,
@Usage("The data to pass as input") @Argument(unquote = false) List<String> input) {
ANSIProgressRenderer ansiProgressRenderer = ansiProgressRenderer();
FlowShellCommand.startFlow(name, input, out, ops(), ansiProgressRenderer != null ? ansiProgressRenderer : new CRaSHANSIProgressRenderer(out));
}
}

View File

@ -42,6 +42,7 @@ import net.corda.node.services.FinalityHandler
import net.corda.node.services.NotaryChangeHandler
import net.corda.node.services.api.*
import net.corda.node.services.config.*
import net.corda.node.services.config.shell.toShellConfig
import net.corda.node.services.events.NodeSchedulerService
import net.corda.node.services.events.ScheduledActivityObserver
import net.corda.node.services.identity.PersistentIdentityService
@ -56,7 +57,6 @@ import net.corda.node.services.transactions.*
import net.corda.node.services.upgrade.ContractUpgradeServiceImpl
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSoftLockManager
import net.corda.node.shell.InteractiveShell
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.JVMAgentRegistry
import net.corda.node.utilities.NodeBuildProperties
@ -67,6 +67,7 @@ import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.HibernateConfiguration
import net.corda.nodeapi.internal.storeLegalIdentity
import net.corda.tools.shell.InteractiveShell
import org.apache.activemq.artemis.utils.ReusableLatch
import org.hibernate.type.descriptor.java.JavaTypeDescriptorRegistry
import org.slf4j.Logger
@ -258,7 +259,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
tokenizableServices = nodeServices + cordaServices + schedulerService
registerCordappFlows(smm)
_services.rpcFlows += cordappLoader.cordapps.flatMap { it.rpcFlows }
startShell(rpcOps)
startShell()
Pair(StartedNodeImpl(this@AbstractNode, _services, nodeInfo, checkpointStorage, smm, attachments, network, database, rpcOps, flowStarter, notaryService), schedulerService)
}
networkMapUpdater = NetworkMapUpdater(services.networkMapCache,
@ -296,9 +297,12 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
*/
protected abstract fun getRxIoScheduler(): Scheduler
open fun startShell(rpcOps: CordaRPCOps) {
open fun startShell() {
if (configuration.shouldInitCrashShell()) {
InteractiveShell.startShell(configuration, rpcOps, securityManager, _services.identityService, _services.database)
if (configuration.rpcOptions.address == null) {
throw ConfigurationException("Cannot init CrashShell because node RPC address is not set (via 'rpcSettings' option).")
}
InteractiveShell.startShell(configuration.toShellConfig())
}
}

View File

@ -1,6 +1,7 @@
package net.corda.node.internal
import com.codahale.metrics.JmxReporter
import net.corda.client.rpc.internal.KryoClientSerializationScheme
import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.concurrent.thenMatch
@ -26,9 +27,8 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.node.services.api.NodePropertiesStore
import net.corda.node.services.api.SchemaService
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.SecurityConfiguration
import net.corda.node.services.config.VerifierType
import net.corda.node.services.config.*
import net.corda.node.services.config.shell.shellUser
import net.corda.node.services.messaging.*
import net.corda.node.services.rpc.ArtemisRpcBroker
import net.corda.node.services.transactions.InMemoryTransactionVerifierService
@ -159,7 +159,7 @@ open class Node(configuration: NodeConfiguration,
val securityManagerConfig = configuration.security?.authService ?:
SecurityConfiguration.AuthService.fromUsers(configuration.rpcUsers)
securityManager = RPCSecurityManagerImpl(securityManagerConfig)
securityManager = RPCSecurityManagerImpl(if (configuration.shouldInitCrashShell()) securityManagerConfig.copyWithAdditionalUser(configuration.shellUser()) else securityManagerConfig)
val serverAddress = configuration.messagingServerAddress ?: makeLocalMessageBroker(networkParameters)
val rpcServerAddresses = if (configuration.rpcOptions.standAloneBroker) {
@ -373,11 +373,13 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply {
registerScheme(KryoServerSerializationScheme())
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoClientSerializationScheme())
},
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = KRYO_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader))
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader),
rpcClientContext = if (configuration.shouldInitCrashShell()) KRYO_RPC_CLIENT_CONTEXT.withClassLoader(classloader) else null) //even Shell embeded in the node connects via RPC to the node
}
private var rpcMessagingClient: RPCMessagingClient? = null

View File

@ -12,12 +12,13 @@ import net.corda.node.*
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.NodeConfigurationImpl
import net.corda.node.services.config.shouldStartLocalShell
import net.corda.node.services.config.shouldStartSSHDaemon
import net.corda.node.services.transactions.bftSMaRtSerialFilter
import net.corda.node.shell.InteractiveShell
import net.corda.node.utilities.registration.HTTPNetworkRegistrationService
import net.corda.node.utilities.registration.NetworkRegistrationHelper
import net.corda.nodeapi.internal.addShutdownHook
import net.corda.nodeapi.internal.config.UnknownConfigurationKeysException
import net.corda.tools.shell.InteractiveShell
import org.fusesource.jansi.Ansi
import org.fusesource.jansi.AnsiConsole
import org.slf4j.bridge.SLF4JBridgeHandler
@ -153,12 +154,15 @@ open class NodeStartup(val args: Array<String>) {
if (conf.shouldStartLocalShell()) {
startedNode.internals.startupComplete.then {
try {
InteractiveShell.runLocalShell(startedNode)
InteractiveShell.runLocalShell( {startedNode.dispose()} )
} catch (e: Throwable) {
logger.error("Shell failed to start", e)
}
}
}
if (conf.shouldStartSSHDaemon()) {
Node.printBasicNodeInfo("SSH server listening on port", conf.sshd!!.port.toString())
}
},
{ th ->
logger.error("Unexpected exception during registration", th)

View File

@ -14,6 +14,7 @@ import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.config.User
import net.corda.nodeapi.internal.config.parseAs
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.tools.shell.SSHDConfiguration
import java.net.URL
import java.nio.file.Path
import java.time.Duration
@ -253,8 +254,6 @@ data class CertChainPolicyConfig(val role: String, private val policy: CertChain
}
}
data class SSHDConfiguration(val port: Int)
// Supported types of authentication/authorization data providers
enum class AuthDataSourceType {
// External RDBMS
@ -290,6 +289,8 @@ data class SecurityConfiguration(val authService: SecurityConfiguration.AuthServ
}
}
fun copyWithAdditionalUser(user: User) = AuthService(dataSource.copyWithAdditionalUser(user), id, options)
// Optional components: cache
data class Options(val cache: Options.Cache?) {
@ -317,6 +318,12 @@ data class SecurityConfiguration(val authService: SecurityConfiguration.AuthServ
AuthDataSourceType.DB -> require(users == null && connection != null)
}
}
fun copyWithAdditionalUser(user: User) : DataSource{
val extendedList = this.users?.toMutableList()?: mutableListOf()
extendedList.add(user)
return DataSource(this.type, this.passwordEncryption, this.connection, listOf(*extendedList.toTypedArray()))
}
}
companion object {

View File

@ -5,7 +5,6 @@ import java.nio.file.Path
import java.nio.file.Paths
data class SslOptions(override val certificatesDirectory: Path, override val keyStorePassword: String, override val trustStorePassword: String) : SSLConfiguration {
constructor(certificatesDirectory: String, keyStorePassword: String, trustStorePassword: String) : this(certificatesDirectory.toAbsolutePath(), keyStorePassword, trustStorePassword)
fun copy(certificatesDirectory: String = this.certificatesDirectory.toString(), keyStorePassword: String = this.keyStorePassword, trustStorePassword: String = this.trustStorePassword): SslOptions = copy(certificatesDirectory = certificatesDirectory.toAbsolutePath(), keyStorePassword = keyStorePassword, trustStorePassword = trustStorePassword)
}

View File

@ -0,0 +1,44 @@
package net.corda.node.services.config.shell
import net.corda.core.internal.div
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.services.Permissions
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.shouldInitCrashShell
import net.corda.nodeapi.internal.config.User
import net.corda.tools.shell.ShellConfiguration
import net.corda.tools.shell.ShellConfiguration.Companion.COMMANDS_DIR
import net.corda.tools.shell.ShellConfiguration.Companion.CORDAPPS_DIR
import net.corda.tools.shell.ShellConfiguration.Companion.SSHD_HOSTKEY_DIR
import net.corda.tools.shell.ShellConfiguration.Companion.SSH_PORT
import net.corda.tools.shell.ShellSslOptions
//re-packs data to Shell specific classes
fun NodeConfiguration.toShellConfig(): ShellConfiguration {
val sslConfiguration = if (this.rpcOptions.useSsl) {
with(this.rpcOptions.sslConfig) {
ShellSslOptions(sslKeystore,
keyStorePassword,
trustStoreFile,
trustStorePassword)
}
} else {
null
}
val localShellUser: User = localShellUser()
return ShellConfiguration(
commandsDirectory = this.baseDirectory / COMMANDS_DIR,
cordappsDirectory = this.baseDirectory.toString() / CORDAPPS_DIR,
user = localShellUser.username,
password = localShellUser.password,
hostAndPort = this.rpcOptions.address ?: NetworkHostAndPort("localhost", SSH_PORT),
ssl = sslConfiguration,
sshdPort = this.sshd?.port,
sshHostKeyDirectory = this.baseDirectory / SSHD_HOSTKEY_DIR,
noLocalShell = this.noLocalShell)
}
private fun localShellUser() = User("shell", "shell", setOf(Permissions.all()))
fun NodeConfiguration.shellUser() = shouldInitCrashShell()?.let { localShellUser() }

View File

@ -1,34 +0,0 @@
package net.corda.node.shell
import net.corda.core.context.Actor
import net.corda.core.context.InvocationContext
import net.corda.core.identity.CordaX500Name
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.internal.security.Password
import net.corda.node.internal.security.RPCSecurityManager
import net.corda.node.internal.security.tryAuthenticate
import org.crsh.auth.AuthInfo
import org.crsh.auth.AuthenticationPlugin
import org.crsh.plugin.CRaSHPlugin
class CordaAuthenticationPlugin(private val rpcOps: CordaRPCOps, private val securityManager: RPCSecurityManager, private val nodeLegalName: CordaX500Name) : CRaSHPlugin<AuthenticationPlugin<String>>(), AuthenticationPlugin<String> {
override fun getImplementation(): AuthenticationPlugin<String> = this
override fun getName(): String = "corda"
override fun authenticate(username: String?, credential: String?): AuthInfo {
if (username == null || credential == null) {
return AuthInfo.UNSUCCESSFUL
}
val authorizingSubject = securityManager.tryAuthenticate(username, Password(credential))
if (authorizingSubject != null) {
val actor = Actor(Actor.Id(username), securityManager.id, nodeLegalName)
return CordaSSHAuthInfo(true, makeRPCOpsWithContext(rpcOps, InvocationContext.rpc(actor), authorizingSubject))
}
return AuthInfo.UNSUCCESSFUL
}
override fun getCredentialType(): Class<String> = String::class.java
}

View File

@ -1,9 +0,0 @@
package net.corda.node.shell
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.utilities.ANSIProgressRenderer
import org.crsh.auth.AuthInfo
class CordaSSHAuthInfo(val successful: Boolean, val rpcOps: CordaRPCOps, val ansiProgressRenderer: ANSIProgressRenderer? = null) : AuthInfo {
override fun isSuccessful(): Boolean = successful
}

View File

@ -1,126 +0,0 @@
package net.corda.node.shell
import net.corda.core.flows.StateMachineRunId
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.context.InvocationContext
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.messaging.StateMachineUpdate.Added
import net.corda.core.messaging.StateMachineUpdate.Removed
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.Try
import org.crsh.text.Color
import org.crsh.text.Decoration
import org.crsh.text.RenderPrintWriter
import org.crsh.text.ui.LabelElement
import org.crsh.text.ui.Overflow
import org.crsh.text.ui.RowElement
import org.crsh.text.ui.TableElement
import rx.Subscriber
class FlowWatchPrintingSubscriber(private val toStream: RenderPrintWriter) : Subscriber<Any>() {
private val indexMap = HashMap<StateMachineRunId, Int>()
private val table = createStateMachinesTable()
val future = openFuture<Unit>()
init {
// The future is public and can be completed by something else to indicate we don't wish to follow
// anymore (e.g. the user pressing Ctrl-C).
future.then { unsubscribe() }
}
@Synchronized
override fun onCompleted() {
// The observable of state machines will never complete.
future.set(Unit)
}
@Synchronized
override fun onNext(t: Any?) {
if (t is StateMachineUpdate) {
toStream.cls()
createStateMachinesRow(t)
toStream.print(table)
toStream.println("Waiting for completion or Ctrl-C ... ")
toStream.flush()
}
}
@Synchronized
override fun onError(e: Throwable) {
toStream.println("Observable completed with an error")
future.setException(e)
}
private fun stateColor(update: StateMachineUpdate): Color {
return when (update) {
is Added -> Color.blue
is Removed -> if (update.result.isSuccess) Color.green else Color.red
}
}
private fun createStateMachinesTable(): TableElement {
val table = TableElement(1, 2, 1, 2).overflow(Overflow.HIDDEN).rightCellPadding(1)
val header = RowElement(true).add("Id", "Flow name", "Initiator", "Status").style(Decoration.bold.fg(Color.black).bg(Color.white))
table.add(header)
return table
}
// TODO Add progress tracker?
private fun createStateMachinesRow(smmUpdate: StateMachineUpdate) {
when (smmUpdate) {
is Added -> {
table.add(RowElement().add(
LabelElement(formatFlowId(smmUpdate.id)),
LabelElement(formatFlowName(smmUpdate.stateMachineInfo.flowLogicClassName)),
LabelElement(formatInvocationContext(smmUpdate.stateMachineInfo.invocationContext)),
LabelElement("In progress")
).style(stateColor(smmUpdate).fg()))
indexMap[smmUpdate.id] = table.rows.size - 1
}
is Removed -> {
val idx = indexMap[smmUpdate.id]
if (idx != null) {
val oldRow = table.rows[idx]
val flowNameLabel = oldRow.getCol(1) as LabelElement
val flowInitiatorLabel = oldRow.getCol(2) as LabelElement
table.rows[idx] = RowElement().add(
LabelElement(formatFlowId(smmUpdate.id)),
LabelElement(flowNameLabel.value),
LabelElement(flowInitiatorLabel.value),
LabelElement(formatFlowResult(smmUpdate.result))
).style(stateColor(smmUpdate).fg())
}
}
}
}
private fun formatFlowName(flowName: String): String {
val camelCaseRegex = Regex("(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
val name = flowName.split('.', '$').last()
// Split CamelCase and get rid of "flow" at the end if present.
return camelCaseRegex.split(name).filter { it.compareTo("Flow", true) != 0 }.joinToString(" ")
}
private fun formatFlowId(flowId: StateMachineRunId): String {
return flowId.toString().removeSurrounding("[", "]")
}
private fun formatInvocationContext(context: InvocationContext): String {
return context.principal().name
}
private fun formatFlowResult(flowResult: Try<*>): String {
fun successFormat(value: Any?): String {
return when (value) {
is SignedTransaction -> "Tx ID: " + value.id.toString()
is kotlin.Unit -> "No return value"
null -> "No return value"
else -> value.toString()
}
}
return when (flowResult) {
is Try.Success -> successFormat(flowResult.value)
is Try.Failure -> flowResult.exception.message ?: flowResult.exception.toString()
}
}
}

View File

@ -1,587 +0,0 @@
package net.corda.node.shell
import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.*
import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
import com.google.common.io.Closeables
import net.corda.client.jackson.JacksonSupport
import net.corda.client.jackson.StringToMethodCallParser
import net.corda.client.rpc.PermissionException
import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.UniqueIdentifier
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.IdentityService
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.internal.Node
import net.corda.node.internal.StartedNode
import net.corda.node.internal.security.AdminSubject
import net.corda.node.internal.security.RPCSecurityManager
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT
import net.corda.node.services.messaging.RpcAuthContext
import net.corda.node.utilities.ANSIProgressRenderer
import net.corda.node.utilities.StdoutANSIProgressRenderer
import net.corda.nodeapi.internal.persistence.CordaPersistence
import org.crsh.command.InvocationContext
import org.crsh.console.jline.JLineProcessor
import org.crsh.console.jline.TerminalFactory
import org.crsh.console.jline.console.ConsoleReader
import org.crsh.lang.impl.java.JavaLanguage
import org.crsh.plugin.CRaSHPlugin
import org.crsh.plugin.PluginContext
import org.crsh.plugin.PluginLifeCycle
import org.crsh.plugin.ServiceLoaderDiscovery
import org.crsh.shell.Shell
import org.crsh.shell.ShellFactory
import org.crsh.shell.impl.command.ExternalResolver
import org.crsh.text.Color
import org.crsh.text.RenderPrintWriter
import org.crsh.util.InterruptHandler
import org.crsh.util.Utils
import org.crsh.vfs.FS
import org.crsh.vfs.spi.file.FileMountFactory
import org.crsh.vfs.spi.url.ClassPathMountFactory
import org.json.JSONObject
import org.slf4j.LoggerFactory
import rx.Observable
import rx.Subscriber
import java.io.*
import java.lang.reflect.InvocationTargetException
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import kotlin.concurrent.thread
// TODO: Add command history.
// TODO: Command completion.
// TODO: Do something sensible with commands that return a future.
// TODO: Configure default renderers, send objects down the pipeline, add commands to do json/xml/yaml outputs.
// TODO: Add a command to view last N lines/tail/control log4j2 loggers.
// TODO: Review or fix the JVM commands which have bitrotted and some are useless.
// TODO: Get rid of the 'java' command, it's kind of worthless.
// TODO: Fix up the 'dashboard' command which has some rendering issues.
// TODO: Resurrect or reimplement the mail plugin.
// TODO: Make it notice new shell commands added after the node started.
object InteractiveShell {
private val log = LoggerFactory.getLogger(javaClass)
private lateinit var node: StartedNode<Node>
@VisibleForTesting
internal lateinit var database: CordaPersistence
private lateinit var rpcOps: CordaRPCOps
private lateinit var securityManager: RPCSecurityManager
private lateinit var identityService: IdentityService
private var shell: Shell? = null
private lateinit var nodeLegalName: CordaX500Name
/**
* Starts an interactive shell connected to the local terminal. This shell gives administrator access to the node
* internals.
*/
fun startShell(configuration: NodeConfiguration, cordaRPCOps: CordaRPCOps, securityManager: RPCSecurityManager, identityService: IdentityService, database: CordaPersistence) {
this.rpcOps = cordaRPCOps
this.securityManager = securityManager
this.identityService = identityService
this.nodeLegalName = configuration.myLegalName
this.database = database
val dir = configuration.baseDirectory
val runSshDaemon = configuration.sshd != null
val config = Properties()
if (runSshDaemon) {
val sshKeysDir = dir / "sshkey"
sshKeysDir.toFile().mkdirs()
// Enable SSH access. Note: these have to be strings, even though raw object assignments also work.
config["crash.ssh.keypath"] = (sshKeysDir / "hostkey.pem").toString()
config["crash.ssh.keygen"] = "true"
config["crash.ssh.port"] = configuration.sshd?.port.toString()
config["crash.auth"] = "corda"
}
ExternalResolver.INSTANCE.addCommand("run", "Runs a method from the CordaRPCOps interface on the node.", RunShellCommand::class.java)
ExternalResolver.INSTANCE.addCommand("flow", "Commands to work with flows. Flows are how you can change the ledger.", FlowShellCommand::class.java)
ExternalResolver.INSTANCE.addCommand("start", "An alias for 'flow start'", StartShellCommand::class.java)
shell = ShellLifecycle(dir).start(config)
if (runSshDaemon) {
Node.printBasicNodeInfo("SSH server listening on port", configuration.sshd!!.port.toString())
}
}
fun runLocalShell(node: StartedNode<Node>) {
val terminal = TerminalFactory.create()
val consoleReader = ConsoleReader("Corda", FileInputStream(FileDescriptor.`in`), System.out, terminal)
val jlineProcessor = JLineProcessor(terminal.isAnsiSupported, shell, consoleReader, System.out)
InterruptHandler { jlineProcessor.interrupt() }.install()
thread(name = "Command line shell processor", isDaemon = true) {
// Give whoever has local shell access administrator access to the node.
val context = RpcAuthContext(net.corda.core.context.InvocationContext.shell(), AdminSubject("SHELL_USER"))
CURRENT_RPC_CONTEXT.set(context)
Emoji.renderIfSupported {
jlineProcessor.run()
}
}
thread(name = "Command line shell terminator", isDaemon = true) {
// Wait for the shell to finish.
jlineProcessor.closed()
log.info("Command shell has exited")
terminal.restore()
node.dispose()
}
}
class ShellLifecycle(val dir: Path) : PluginLifeCycle() {
fun start(config: Properties): Shell {
val classLoader = this.javaClass.classLoader
val classpathDriver = ClassPathMountFactory(classLoader)
val fileDriver = FileMountFactory(Utils.getCurrentDirectory())
val extraCommandsPath = (dir / "shell-commands").toAbsolutePath().createDirectories()
val commandsFS = FS.Builder()
.register("file", fileDriver)
.mount("file:" + extraCommandsPath)
.register("classpath", classpathDriver)
.mount("classpath:/net/corda/node/shell/")
.mount("classpath:/crash/commands/")
.build()
val confFS = FS.Builder()
.register("classpath", classpathDriver)
.mount("classpath:/crash")
.build()
val discovery = object : ServiceLoaderDiscovery(classLoader) {
override fun getPlugins(): Iterable<CRaSHPlugin<*>> {
// Don't use the Java language plugin (we may not have tools.jar available at runtime), this
// will cause any commands using JIT Java compilation to be suppressed. In CRaSH upstream that
// is only the 'jmx' command.
return super.getPlugins().filterNot { it is JavaLanguage } + CordaAuthenticationPlugin(rpcOps, securityManager, nodeLegalName)
}
}
val attributes = mapOf(
"ops" to rpcOps,
"mapper" to yamlInputMapper
)
val context = PluginContext(discovery, attributes, commandsFS, confFS, classLoader)
context.refresh()
this.config = config
start(context)
return context.getPlugin(ShellFactory::class.java).create(null, CordaSSHAuthInfo(false, makeRPCOpsWithContext(rpcOps, net.corda.core.context.InvocationContext.shell(), AdminSubject("SHELL_USER")), StdoutANSIProgressRenderer))
}
}
private val yamlInputMapper: ObjectMapper by lazy {
// Return a standard Corda Jackson object mapper, configured to use YAML by default and with extra
// serializers.
JacksonSupport.createInMemoryMapper(identityService, YAMLFactory(), true).apply {
val rpcModule = SimpleModule()
rpcModule.addDeserializer(InputStream::class.java, InputStreamDeserializer)
rpcModule.addDeserializer(UniqueIdentifier::class.java, UniqueIdentifierDeserializer)
rpcModule.addDeserializer(UUID::class.java, UUIDDeserializer)
registerModule(rpcModule)
}
}
private object NodeInfoSerializer : JsonSerializer<NodeInfo>() {
override fun serialize(nodeInfo: NodeInfo, gen: JsonGenerator, serializers: SerializerProvider) {
val json = JSONObject()
json["addresses"] = nodeInfo.addresses.map { address -> address.serialise() }
json["legalIdentities"] = nodeInfo.legalIdentities.map { address -> address.serialise() }
json["platformVersion"] = nodeInfo.platformVersion
json["serial"] = nodeInfo.serial
gen.writeRaw(json.toString())
}
private fun NetworkHostAndPort.serialise() = this.toString()
private fun Party.serialise() = JSONObject().put("name", this.name)
private operator fun JSONObject.set(key: String, value: Any?): JSONObject {
return put(key, value)
}
}
private fun createOutputMapper(): ObjectMapper {
return JacksonSupport.createNonRpcMapper().apply {
// Register serializers for stateful objects from libraries that are special to the RPC system and don't
// make sense to print out to the screen. For classes we own, annotations can be used instead.
val rpcModule = SimpleModule()
rpcModule.addSerializer(Observable::class.java, ObservableSerializer)
rpcModule.addSerializer(InputStream::class.java, InputStreamSerializer)
rpcModule.addSerializer(NodeInfo::class.java, NodeInfoSerializer)
registerModule(rpcModule)
disable(SerializationFeature.FAIL_ON_EMPTY_BEANS)
enable(SerializationFeature.INDENT_OUTPUT)
}
}
// TODO: This should become the default renderer rather than something used specifically by commands.
private val outputMapper by lazy { createOutputMapper() }
/**
* Called from the 'flow' shell command. Takes a name fragment and finds a matching flow, or prints out
* the list of options if the request is ambiguous. Then parses [inputData] as constructor arguments using
* the [runFlowFromString] method and starts the requested flow. Ctrl-C can be used to cancel.
*/
@JvmStatic
fun runFlowByNameFragment(nameFragment: String, inputData: String, output: RenderPrintWriter, rpcOps: CordaRPCOps, ansiProgressRenderer: ANSIProgressRenderer) {
val matches = rpcOps.registeredFlows().filter { nameFragment in it }
if (matches.isEmpty()) {
output.println("No matching flow found, run 'flow list' to see your options.", Color.red)
return
} else if (matches.size > 1) {
output.println("Ambiguous name provided, please be more specific. Your options are:")
matches.forEachIndexed { i, s -> output.println("${i + 1}. $s", Color.yellow) }
return
}
val clazz: Class<FlowLogic<*>> = uncheckedCast(Class.forName(matches.single()))
try {
// Show the progress tracker on the console until the flow completes or is interrupted with a
// Ctrl-C keypress.
val stateObservable = runFlowFromString({ clazz, args -> rpcOps.startTrackedFlowDynamic(clazz, *args) }, inputData, clazz)
val latch = CountDownLatch(1)
ansiProgressRenderer.render(stateObservable, { latch.countDown() })
try {
// Wait for the flow to end and the progress tracker to notice. By the time the latch is released
// the tracker is done with the screen.
latch.await()
} catch (e: InterruptedException) {
// TODO: When the flow framework allows us to kill flows mid-flight, do so here.
}
} catch (e: NoApplicableConstructor) {
output.println("No matching constructor found:", Color.red)
e.errors.forEach { output.println("- $it", Color.red) }
} catch (e: PermissionException) {
output.println(e.message ?: "Access denied", Color.red)
} finally {
InputStreamDeserializer.closeAll()
}
}
class NoApplicableConstructor(val errors: List<String>) : CordaException(this.toString()) {
override fun toString() = (listOf("No applicable constructor for flow. Problems were:") + errors).joinToString(System.lineSeparator())
}
// TODO: This utility is generally useful and might be better moved to the node class, or an RPC, if we can commit to making it stable API.
/**
* Given a [FlowLogic] class and a string in one-line Yaml form, finds an applicable constructor and starts
* the flow, returning the created flow logic. Useful for lightweight invocation where text is preferable
* to statically typed, compiled code.
*
* See the [StringToMethodCallParser] class to learn more about limitations and acceptable syntax.
*
* @throws NoApplicableConstructor if no constructor could be found for the given set of types.
*/
@Throws(NoApplicableConstructor::class)
fun <T> runFlowFromString(invoke: (Class<out FlowLogic<T>>, Array<out Any?>) -> FlowProgressHandle<T>,
inputData: String,
clazz: Class<out FlowLogic<T>>,
om: ObjectMapper = yamlInputMapper): FlowProgressHandle<T> {
// For each constructor, attempt to parse the input data as a method call. Use the first that succeeds,
// and keep track of the reasons we failed so we can print them out if no constructors are usable.
val parser = StringToMethodCallParser(clazz, om)
val errors = ArrayList<String>()
for (ctor in clazz.constructors) {
var paramNamesFromConstructor: List<String>? = null
fun getPrototype(): List<String> {
val argTypes = ctor.parameterTypes.map { it.simpleName }
return paramNamesFromConstructor!!.zip(argTypes).map { (name, type) -> "$name: $type" }
}
try {
// Attempt construction with the given arguments.
val args = database.transaction {
paramNamesFromConstructor = parser.paramNamesFromConstructor(ctor)
parser.parseArguments(clazz.name, paramNamesFromConstructor!!.zip(ctor.parameterTypes), inputData)
}
if (args.size != ctor.parameterTypes.size) {
errors.add("${getPrototype()}: Wrong number of arguments (${args.size} provided, ${ctor.parameterTypes.size} needed)")
continue
}
val flow = ctor.newInstance(*args) as FlowLogic<*>
if (flow.progressTracker == null) {
errors.add("A flow must override the progress tracker in order to be run from the shell")
continue
}
return invoke(clazz, args)
} catch (e: StringToMethodCallParser.UnparseableCallException.MissingParameter) {
errors.add("${getPrototype()}: missing parameter ${e.paramName}")
} catch (e: StringToMethodCallParser.UnparseableCallException.TooManyParameters) {
errors.add("${getPrototype()}: too many parameters")
} catch (e: StringToMethodCallParser.UnparseableCallException.ReflectionDataMissing) {
val argTypes = ctor.parameterTypes.map { it.simpleName }
errors.add("$argTypes: <constructor missing parameter reflection data>")
} catch (e: StringToMethodCallParser.UnparseableCallException) {
val argTypes = ctor.parameterTypes.map { it.simpleName }
errors.add("$argTypes: ${e.message}")
}
}
throw NoApplicableConstructor(errors)
}
// TODO Filtering on error/success when we will have some sort of flow auditing, for now it doesn't make much sense.
@JvmStatic
fun runStateMachinesView(out: RenderPrintWriter, rpcOps: CordaRPCOps): Any? {
val proxy = rpcOps
val (stateMachines, stateMachineUpdates) = proxy.stateMachinesFeed()
val currentStateMachines = stateMachines.map { StateMachineUpdate.Added(it) }
val subscriber = FlowWatchPrintingSubscriber(out)
database.transaction {
stateMachineUpdates.startWith(currentStateMachines).subscribe(subscriber)
}
var result: Any? = subscriber.future
if (result is Future<*>) {
if (!result.isDone) {
out.cls()
out.println("Waiting for completion or Ctrl-C ... ")
out.flush()
}
try {
result = result.get()
} catch (e: InterruptedException) {
Thread.currentThread().interrupt()
} catch (e: ExecutionException) {
throw e.rootCause
} catch (e: InvocationTargetException) {
throw e.rootCause
}
}
return result
}
@JvmStatic
fun runRPCFromString(input: List<String>, out: RenderPrintWriter, context: InvocationContext<out Any>, cordaRPCOps: CordaRPCOps): Any? {
val parser = StringToMethodCallParser(CordaRPCOps::class.java, context.attributes["mapper"] as ObjectMapper)
val cmd = input.joinToString(" ").trim { it <= ' ' }
if (cmd.toLowerCase().startsWith("startflow")) {
// The flow command provides better support and startFlow requires special handling anyway due to
// the generic startFlow RPC interface which offers no type information with which to parse the
// string form of the command.
out.println("Please use the 'flow' command to interact with flows rather than the 'run' command.", Color.yellow)
return null
}
var result: Any? = null
try {
InputStreamSerializer.invokeContext = context
val call = database.transaction { parser.parse(cordaRPCOps, cmd) }
result = call.call()
if (result != null && result !is kotlin.Unit && result !is Void) {
result = printAndFollowRPCResponse(result, out)
}
if (result is Future<*>) {
if (!result.isDone) {
out.println("Waiting for completion or Ctrl-C ... ")
out.flush()
}
try {
result = result.get()
} catch (e: InterruptedException) {
Thread.currentThread().interrupt()
} catch (e: ExecutionException) {
throw e.rootCause
} catch (e: InvocationTargetException) {
throw e.rootCause
}
}
} catch (e: StringToMethodCallParser.UnparseableCallException) {
out.println(e.message, Color.red)
out.println("Please try 'man run' to learn what syntax is acceptable")
} catch (e: Exception) {
out.println("RPC failed: ${e.rootCause}", Color.red)
} finally {
InputStreamSerializer.invokeContext = null
InputStreamDeserializer.closeAll()
}
return result
}
private fun printAndFollowRPCResponse(response: Any?, out: PrintWriter): CordaFuture<Unit> {
val mapElement: (Any?) -> String = { element -> outputMapper.writerWithDefaultPrettyPrinter().writeValueAsString(element) }
val mappingFunction: (Any?) -> String = { value ->
if (value is Collection<*>) {
value.joinToString(",${System.lineSeparator()} ", "[${System.lineSeparator()} ", "${System.lineSeparator()}]") { element ->
mapElement(element)
}
} else {
mapElement(value)
}
}
return maybeFollow(response, mappingFunction, out)
}
private class PrintingSubscriber(private val printerFun: (Any?) -> String, private val toStream: PrintWriter) : Subscriber<Any>() {
private var count = 0
val future = openFuture<Unit>()
init {
// The future is public and can be completed by something else to indicate we don't wish to follow
// anymore (e.g. the user pressing Ctrl-C).
future.then { unsubscribe() }
}
@Synchronized
override fun onCompleted() {
toStream.println("Observable has completed")
future.set(Unit)
}
@Synchronized
override fun onNext(t: Any?) {
count++
toStream.println("Observation $count: " + printerFun(t))
toStream.flush()
}
@Synchronized
override fun onError(e: Throwable) {
toStream.println("Observable completed with an error")
e.printStackTrace(toStream)
future.setException(e)
}
}
private fun maybeFollow(response: Any?, printerFun: (Any?) -> String, out: PrintWriter): CordaFuture<Unit> {
// Match on a couple of common patterns for "important" observables. It's tough to do this in a generic
// way because observables can be embedded anywhere in the object graph, and can emit other arbitrary
// object graphs that contain yet more observables. So we just look for top level responses that follow
// the standard "track" pattern, and print them until the user presses Ctrl-C
if (response == null) return doneFuture(Unit)
if (response is DataFeed<*, *>) {
out.println("Snapshot:")
out.println(printerFun(response.snapshot))
out.flush()
out.println("Updates:")
return printNextElements(response.updates, printerFun, out)
}
if (response is Observable<*>) {
return printNextElements(response, printerFun, out)
}
out.println(printerFun(response))
return doneFuture(Unit)
}
private fun printNextElements(elements: Observable<*>, printerFun: (Any?) -> String, out: PrintWriter): CordaFuture<Unit> {
val subscriber = PrintingSubscriber(printerFun, out)
uncheckedCast(elements).subscribe(subscriber)
return subscriber.future
}
//region Extra serializers
//
// These serializers are used to enable the user to specify objects that aren't natural data containers in the shell,
// and for the shell to print things out that otherwise wouldn't be usefully printable.
private object ObservableSerializer : JsonSerializer<Observable<*>>() {
override fun serialize(value: Observable<*>, gen: JsonGenerator, serializers: SerializerProvider) {
gen.writeString("(observable)")
}
}
// A file name is deserialized to an InputStream if found.
object InputStreamDeserializer : JsonDeserializer<InputStream>() {
// Keep track of them so we can close them later.
private val streams = Collections.synchronizedSet(HashSet<InputStream>())
override fun deserialize(p: JsonParser, ctxt: DeserializationContext): InputStream {
val stream = object : BufferedInputStream(Files.newInputStream(Paths.get(p.text))) {
override fun close() {
super.close()
streams.remove(this)
}
}
streams += stream
return stream
}
fun closeAll() {
// Clone the set with toList() here so each closed stream can be removed from the set inside close().
streams.toList().forEach { Closeables.closeQuietly(it) }
}
}
// An InputStream found in a response triggers a request to the user to provide somewhere to save it.
private object InputStreamSerializer : JsonSerializer<InputStream>() {
var invokeContext: InvocationContext<*>? = null
override fun serialize(value: InputStream, gen: JsonGenerator, serializers: SerializerProvider) {
try {
val toPath = invokeContext!!.readLine("Path to save stream to (enter to ignore): ", true)
if (toPath == null || toPath.isBlank()) {
gen.writeString("<not saved>")
} else {
val path = Paths.get(toPath)
value.copyTo(path)
gen.writeString("<saved to: ${path.toAbsolutePath()}>")
}
} finally {
try {
value.close()
} catch (e: IOException) {
// Ignore.
}
}
}
}
/**
* String value deserialized to [UniqueIdentifier].
* Any string value used as [UniqueIdentifier.externalId].
* If string contains underscore(i.e. externalId_uuid) then split with it.
* Index 0 as [UniqueIdentifier.externalId]
* Index 1 as [UniqueIdentifier.id]
* */
object UniqueIdentifierDeserializer : JsonDeserializer<UniqueIdentifier>() {
override fun deserialize(p: JsonParser, ctxt: DeserializationContext): UniqueIdentifier {
//Check if externalId and UUID may be separated by underscore.
if (p.text.contains("_")) {
val ids = p.text.split("_")
//Create UUID object from string.
val uuid: UUID = UUID.fromString(ids[1])
//Create UniqueIdentifier object using externalId and UUID.
return UniqueIdentifier(ids[0], uuid)
}
//Any other string used as externalId.
return UniqueIdentifier.fromString(p.text)
}
}
/**
* String value deserialized to [UUID].
* */
object UUIDDeserializer : JsonDeserializer<UUID>() {
override fun deserialize(p: JsonParser, ctxt: DeserializationContext): UUID {
//Create UUID object from string.
return UUID.fromString(p.text)
}
}
//endregion
}

View File

@ -1,17 +0,0 @@
package net.corda.node.shell
import com.fasterxml.jackson.databind.ObjectMapper
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.services.api.ServiceHubInternal
import org.crsh.command.BaseCommand
import org.crsh.shell.impl.command.CRaSHSession
/**
* Simply extends CRaSH BaseCommand to add easy access to the RPC ops class.
*/
open class InteractiveShellCommand : BaseCommand() {
fun ops() = ((context.session as CRaSHSession).authInfo as CordaSSHAuthInfo).rpcOps
fun ansiProgressRenderer() = ((context.session as CRaSHSession).authInfo as CordaSSHAuthInfo).ansiProgressRenderer
fun services() = context.attributes["services"] as ServiceHubInternal
fun objectMapper() = context.attributes["mapper"] as ObjectMapper
}

View File

@ -1,48 +0,0 @@
package net.corda.node.shell
import net.corda.core.context.InvocationContext
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.utilities.getOrThrow
import net.corda.node.internal.security.AuthorizingSubject
import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT
import net.corda.node.services.messaging.RpcAuthContext
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Proxy
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Future
fun makeRPCOpsWithContext(cordaRPCOps: CordaRPCOps, invocationContext:InvocationContext, authorizingSubject: AuthorizingSubject) : CordaRPCOps {
return Proxy.newProxyInstance(CordaRPCOps::class.java.classLoader, arrayOf(CordaRPCOps::class.java), { _, method, args ->
RPCContextRunner(invocationContext, authorizingSubject) {
try {
method.invoke(cordaRPCOps, *(args ?: arrayOf()))
} catch (e: InvocationTargetException) {
// Unpack exception.
throw e.targetException
}
}.get().getOrThrow()
}) as CordaRPCOps
}
private class RPCContextRunner<T>(val invocationContext: InvocationContext, val authorizingSubject: AuthorizingSubject, val block:() -> T): Thread() {
private var result: CompletableFuture<T> = CompletableFuture()
override fun run() {
CURRENT_RPC_CONTEXT.set(RpcAuthContext(invocationContext, authorizingSubject))
try {
result.complete(block())
} catch (e: Throwable) {
result.completeExceptionally(e)
} finally {
CURRENT_RPC_CONTEXT.remove()
}
}
fun get(): Future<T> {
start()
join()
return result
}
}

View File

@ -1,263 +0,0 @@
package net.corda.node.utilities
import net.corda.core.internal.Emoji
import net.corda.core.messaging.FlowProgressHandle
import org.apache.logging.log4j.LogManager
import org.apache.logging.log4j.core.LogEvent
import org.apache.logging.log4j.core.LoggerContext
import org.apache.logging.log4j.core.appender.AbstractOutputStreamAppender
import org.apache.logging.log4j.core.appender.ConsoleAppender
import org.apache.logging.log4j.core.appender.OutputStreamManager
import org.crsh.text.RenderPrintWriter
import org.fusesource.jansi.Ansi
import org.fusesource.jansi.AnsiConsole
import org.fusesource.jansi.AnsiOutputStream
import rx.Subscription
abstract class ANSIProgressRenderer {
private var subscriptionIndex: Subscription? = null
private var subscriptionTree: Subscription? = null
protected var usingANSI = false
protected var checkEmoji = false
protected var treeIndex: Int = 0
protected var tree: List<Pair<Int,String>> = listOf()
private var installedYet = false
private var onDone: () -> Unit = {}
// prevMessagePrinted is just for non-ANSI mode.
private var prevMessagePrinted: String? = null
// prevLinesDraw is just for ANSI mode.
protected var prevLinesDrawn = 0
private fun done(error: Throwable?) {
if (error == null) _render(null)
draw(true, error)
onDone()
}
fun render(flowProgressHandle: FlowProgressHandle<*>, onDone: () -> Unit = {}) {
this.onDone = onDone
_render(flowProgressHandle)
}
protected abstract fun printLine(line:String)
protected abstract fun printAnsi(ansi:Ansi)
protected abstract fun setup()
private fun _render(flowProgressHandle: FlowProgressHandle<*>?) {
subscriptionIndex?.unsubscribe()
subscriptionTree?.unsubscribe()
treeIndex = 0
tree = listOf()
if (!installedYet) {
setup()
installedYet = true
}
prevMessagePrinted = null
prevLinesDrawn = 0
draw(true)
flowProgressHandle?.apply {
stepsTreeIndexFeed?.apply {
treeIndex = snapshot
subscriptionIndex = updates.subscribe({
treeIndex = it
draw(true)
}, { done(it) }, { done(null) })
}
stepsTreeFeed?.apply {
tree = snapshot
subscriptionTree = updates.subscribe({
tree = it
draw(true)
}, { done(it) }, { done(null) })
}
}
}
@Synchronized protected fun draw(moveUp: Boolean, error: Throwable? = null) {
if (!usingANSI) {
val currentMessage = tree.getOrNull(treeIndex)?.second
if (currentMessage != null && currentMessage != prevMessagePrinted) {
printLine(currentMessage)
prevMessagePrinted = currentMessage
}
return
}
fun printingBody() {
// Handle the case where the number of steps in a progress tracker is changed during execution.
val ansi = Ansi()
if (prevLinesDrawn > 0 && moveUp)
ansi.cursorUp(prevLinesDrawn)
// Put a blank line between any logging and us.
ansi.eraseLine()
ansi.newline()
if (tree.isEmpty()) return
var newLinesDrawn = 1 + renderLevel(ansi, error != null)
if (error != null) {
ansi.a("${Emoji.skullAndCrossbones} ${error.message}")
ansi.eraseLine(Ansi.Erase.FORWARD)
ansi.newline()
newLinesDrawn++
}
if (newLinesDrawn < prevLinesDrawn) {
// If some steps were removed from the progress tracker, we don't want to leave junk hanging around below.
val linesToClear = prevLinesDrawn - newLinesDrawn
repeat(linesToClear) {
ansi.eraseLine()
ansi.newline()
}
ansi.cursorUp(linesToClear)
}
prevLinesDrawn = newLinesDrawn
printAnsi(ansi)
}
if (checkEmoji) {
Emoji.renderIfSupported(::printingBody)
} else {
printingBody()
}
}
// Returns number of lines rendered.
private fun renderLevel(ansi: Ansi, error: Boolean): Int {
with(ansi) {
var lines = 0
for ((index, step) in tree.withIndex()) {
val marker = when {
index < treeIndex -> "${Emoji.greenTick} "
treeIndex == tree.lastIndex -> "${Emoji.greenTick} "
index == treeIndex -> "${Emoji.rightArrow} "
error -> "${Emoji.noEntry} "
else -> " " // Not reached yet.
}
a(" ".repeat(step.first))
a(marker)
val active = index == treeIndex
if (active) bold()
a(step.second)
if (active) boldOff()
eraseLine(Ansi.Erase.FORWARD)
newline()
lines++
}
return lines
}
}
}
class CRaSHANSIProgressRenderer(val renderPrintWriter:RenderPrintWriter) : ANSIProgressRenderer() {
override fun printLine(line: String) {
renderPrintWriter.println(line)
}
override fun printAnsi(ansi: Ansi) {
renderPrintWriter.print(ansi)
renderPrintWriter.flush()
}
override fun setup() {
// We assume SSH always use ANSI.
usingANSI = true
}
}
/**
* Knows how to render a [FlowProgressHandle] to the terminal using coloured, emoji-fied output. Useful when writing small
* command line tools, demos, tests etc. Just call [draw] method and it will go ahead and start drawing
* if the terminal supports it. Otherwise it just prints out the name of the step whenever it changes.
*
* When a progress tracker is on the screen, it takes over the bottom part and reconfigures logging so that, assuming
* 1 log event == 1 line, the progress tracker is always glued to the bottom and logging scrolls above it.
*
* TODO: More thread safety
*/
object StdoutANSIProgressRenderer : ANSIProgressRenderer() {
override fun setup() {
AnsiConsole.systemInstall()
checkEmoji = true
// This line looks weird as hell because the magic code to decide if we really have a TTY or not isn't
// actually exposed anywhere as a function (weak sauce). So we have to rely on our knowledge of jansi
// implementation details.
usingANSI = AnsiConsole.wrapOutputStream(System.out) !is AnsiOutputStream
if (usingANSI) {
// This super ugly code hacks into log4j and swaps out its console appender for our own. It's a bit simpler
// than doing things the official way with a dedicated plugin, etc, as it avoids mucking around with all
// the config XML and lifecycle goop.
val manager = LogManager.getContext(false) as LoggerContext
val consoleAppender = manager.configuration.appenders.values.filterIsInstance<ConsoleAppender>().single { it.name == "Console-Appender" }
val scrollingAppender = object : AbstractOutputStreamAppender<OutputStreamManager>(
consoleAppender.name, consoleAppender.layout, consoleAppender.filter,
consoleAppender.ignoreExceptions(), true, consoleAppender.manager) {
override fun append(event: LogEvent) {
// We lock on the renderer to avoid threads that are logging to the screen simultaneously messing
// things up. Of course this slows stuff down a bit, but only whilst this little utility is in use.
// Eventually it will be replaced with a real GUI and we can delete all this.
synchronized(StdoutANSIProgressRenderer) {
if (tree.isNotEmpty()) {
val ansi = Ansi.ansi()
repeat(prevLinesDrawn) { ansi.eraseLine().cursorUp(1).eraseLine() }
System.out.print(ansi)
System.out.flush()
}
super.append(event)
if (tree.isNotEmpty())
draw(false)
}
}
}
scrollingAppender.start()
manager.configuration.appenders[consoleAppender.name] = scrollingAppender
val loggerConfigs = manager.configuration.loggers.values
for (config in loggerConfigs) {
val appenderRefs = config.appenderRefs
val consoleAppenders = config.appenders.filter { it.value is ConsoleAppender }.keys
consoleAppenders.forEach { config.removeAppender(it) }
appenderRefs.forEach { config.addAppender(manager.configuration.appenders[it.ref], it.level, it.filter) }
}
manager.updateLoggers()
}
}
override fun printLine(line:String) {
System.out.println(line)
}
override fun printAnsi(ansi: Ansi) {
// Need to force a flush here in order to ensure stderr/stdout sync up properly.
System.out.print(ansi)
System.out.flush()
}
}

View File

@ -1,15 +0,0 @@
package net.corda.node.shell.base
// Note that this file MUST be in a sub-directory called "base" relative to the path
// given in the configuration code in InteractiveShell.
welcome = """
Welcome to the Corda interactive shell.
Useful commands include 'help' to see what is available, and 'bye' to shut down the node.
"""
prompt = { ->
return "${new Date()}>>> "
}

View File

@ -1,98 +0,0 @@
package net.corda.node
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory
import net.corda.client.jackson.JacksonSupport
import net.corda.core.contracts.Amount
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.FlowProgressHandleImpl
import net.corda.core.utilities.ProgressTracker
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.node.shell.InteractiveShell
import net.corda.node.internal.configureDatabase
import net.corda.testing.core.TestIdentity
import net.corda.testing.node.MockServices
import net.corda.testing.node.makeTestIdentityService
import net.corda.testing.internal.rigorousMock
import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.Observable
import java.util.*
import kotlin.test.assertEquals
class InteractiveShellTest {
companion object {
private val megaCorp = TestIdentity(CordaX500Name("MegaCorp", "London", "GB"))
}
@Before
fun setup() {
InteractiveShell.database = configureDatabase(MockServices.makeTestDataSourceProperties(), DatabaseConfig(), rigorousMock())
}
@After
fun shutdown() {
InteractiveShell.database.close()
}
@Suppress("UNUSED")
class FlowA(val a: String) : FlowLogic<String>() {
constructor(b: Int?) : this(b.toString())
constructor(b: Int?, c: String) : this(b.toString() + c)
constructor(amount: Amount<Currency>) : this(amount.toString())
constructor(pair: Pair<Amount<Currency>, SecureHash.SHA256>) : this(pair.toString())
constructor(party: Party) : this(party.name.toString())
override val progressTracker = ProgressTracker()
override fun call() = a
}
private val ids = makeTestIdentityService(megaCorp.identity)
private val om = JacksonSupport.createInMemoryMapper(ids, YAMLFactory())
private fun check(input: String, expected: String) {
var output: String? = null
InteractiveShell.runFlowFromString( { clazz, args ->
val instance = clazz.getConstructor(*args.map { it!!::class.java }.toTypedArray()).newInstance(*args) as FlowA
output = instance.a
val future = openFuture<String>()
future.set("ABC")
FlowProgressHandleImpl(StateMachineRunId.createRandom(), future, Observable.just("Some string"))
}, input, FlowA::class.java, om)
assertEquals(expected, output!!, input)
}
@Test
fun flowStartSimple() {
check("a: Hi there", "Hi there")
check("b: 12", "12")
check("b: 12, c: Yo", "12Yo")
}
@Test
fun flowStartWithComplexTypes() = check("amount: £10", "10.00 GBP")
@Test
fun flowStartWithNestedTypes() = check(
"pair: { first: $100.12, second: df489807f81c8c8829e509e1bcb92e6692b9dd9d624b7456435cb2f51dc82587 }",
"($100.12, df489807f81c8c8829e509e1bcb92e6692b9dd9d624b7456435cb2f51dc82587)"
)
@Test(expected = InteractiveShell.NoApplicableConstructor::class)
fun flowStartNoArgs() = check("", "")
@Test(expected = InteractiveShell.NoApplicableConstructor::class)
fun flowMissingParam() = check("c: Yo", "")
@Test(expected = InteractiveShell.NoApplicableConstructor::class)
fun flowTooManyParams() = check("b: 12, c: Yo, d: Bar", "")
@Test
fun party() = check("party: \"${megaCorp.name}\"", megaCorp.name.toString())
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services.config
import net.corda.core.internal.div
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.tools.shell.SSHDConfiguration
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThatThrownBy

View File

@ -11,12 +11,12 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.services.Permissions.Companion.all
import net.corda.node.services.config.CertChainPolicyConfig
import net.corda.node.services.messaging.RPCMessagingClient
import net.corda.node.testsupport.withCertificates
import net.corda.node.testsupport.withKeyStores
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.config.User
import net.corda.testing.common.internal.withCertificates
import net.corda.testing.common.internal.withKeyStores
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.driver.PortAllocation
import net.corda.testing.driver.internal.RandomFree

View File

@ -1,72 +0,0 @@
package net.corda.node.shell
import com.fasterxml.jackson.databind.JsonMappingException
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.module.kotlin.readValue
import net.corda.core.contracts.UniqueIdentifier
import org.junit.Before
import org.junit.Test
import java.util.*
import kotlin.test.assertEquals
class CustomTypeJsonParsingTests {
lateinit var objectMapper: ObjectMapper
//Dummy classes for testing.
data class State(val linearId: UniqueIdentifier) {
constructor() : this(UniqueIdentifier("required-for-json-deserializer"))
}
data class UuidState(val uuid: UUID) {
//Default constructor required for json deserializer.
constructor() : this(UUID.randomUUID())
}
@Before
fun setup() {
objectMapper = ObjectMapper()
val simpleModule = SimpleModule()
simpleModule.addDeserializer(UniqueIdentifier::class.java, InteractiveShell.UniqueIdentifierDeserializer)
simpleModule.addDeserializer(UUID::class.java, InteractiveShell.UUIDDeserializer)
objectMapper.registerModule(simpleModule)
}
@Test
fun `Deserializing UniqueIdentifier by parsing string`() {
val id = "26b37265-a1fd-4c77-b2e0-715917ef619f"
val json = """{"linearId":"$id"}"""
val state = objectMapper.readValue<State>(json)
assertEquals(id, state.linearId.id.toString())
}
@Test
fun `Deserializing UniqueIdentifier by parsing string with underscore`() {
val json = """{"linearId":"extkey564_26b37265-a1fd-4c77-b2e0-715917ef619f"}"""
val state = objectMapper.readValue<State>(json)
assertEquals("extkey564", state.linearId.externalId)
assertEquals("26b37265-a1fd-4c77-b2e0-715917ef619f", state.linearId.id.toString())
}
@Test(expected = JsonMappingException::class)
fun `Deserializing by parsing string contain invalid uuid with underscore`() {
val json = """{"linearId":"extkey564_26b37265-a1fd-4c77-b2e0"}"""
objectMapper.readValue<State>(json)
}
@Test
fun `Deserializing UUID by parsing string`() {
val json = """{"uuid":"26b37265-a1fd-4c77-b2e0-715917ef619f"}"""
val state = objectMapper.readValue<UuidState>(json)
assertEquals("26b37265-a1fd-4c77-b2e0-715917ef619f", state.uuid.toString())
}
@Test(expected = JsonMappingException::class)
fun `Deserializing UUID by parsing invalid uuid string`() {
val json = """{"uuid":"26b37265-a1fd-4c77-b2e0"}"""
objectMapper.readValue<UuidState>(json)
}
}

View File

@ -1,206 +0,0 @@
package net.corda.node.testsupport
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.node.services.config.SslOptions
import net.corda.nodeapi.internal.crypto.*
import org.apache.commons.io.FileUtils
import sun.security.tools.keytool.CertAndKeyGen
import sun.security.x509.X500Name
import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyPair
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.X509Certificate
import java.time.Duration
import java.time.Instant
import java.time.Instant.now
import java.time.temporal.ChronoUnit
import java.util.*
import javax.security.auth.x500.X500Principal
class UnsafeCertificatesFactory(
defaults: Defaults = defaults(),
private val keyType: String = defaults.keyType,
private val signatureAlgorithm: String = defaults.signatureAlgorithm,
private val keySize: Int = defaults.keySize,
private val certificatesValidityWindow: CertificateValidityWindow = defaults.certificatesValidityWindow,
private val keyStoreType: String = defaults.keyStoreType) {
companion object {
private const val KEY_TYPE_RSA = "RSA"
private const val SIG_ALG_SHA_RSA = "SHA1WithRSA"
private const val KEY_SIZE = 1024
private val DEFAULT_DURATION = Duration.of(365, ChronoUnit.DAYS)
private const val DEFAULT_KEYSTORE_TYPE = "JKS"
fun defaults() = Defaults(KEY_TYPE_RSA, SIG_ALG_SHA_RSA, KEY_SIZE, CertificateValidityWindow(now(), DEFAULT_DURATION), DEFAULT_KEYSTORE_TYPE)
}
data class Defaults(
val keyType: String,
val signatureAlgorithm: String,
val keySize: Int,
val certificatesValidityWindow: CertificateValidityWindow,
val keyStoreType: String)
fun createSelfSigned(name: X500Name): UnsafeCertificate = createSelfSigned(name, keyType, signatureAlgorithm, keySize, certificatesValidityWindow)
fun createSelfSigned(name: CordaX500Name) = createSelfSigned(name.asX500Name())
fun createSignedBy(subject: X500Principal, issuer: UnsafeCertificate): UnsafeCertificate = issuer.createSigned(subject, keyType, signatureAlgorithm, keySize, certificatesValidityWindow)
fun createSignedBy(name: CordaX500Name, issuer: UnsafeCertificate): UnsafeCertificate = issuer.createSigned(name, keyType, signatureAlgorithm, keySize, certificatesValidityWindow)
fun newKeyStore(password: String) = UnsafeKeyStore.create(keyStoreType, password)
fun newKeyStores(keyStorePassword: String, trustStorePassword: String): KeyStores = KeyStores(newKeyStore(keyStorePassword), newKeyStore(trustStorePassword))
}
class KeyStores(val keyStore: UnsafeKeyStore, val trustStore: UnsafeKeyStore) {
fun save(directory: Path = Files.createTempDirectory(null)): AutoClosableSSLConfiguration {
val keyStoreFile = keyStore.toTemporaryFile("sslkeystore", directory = directory)
val trustStoreFile = trustStore.toTemporaryFile("truststore", directory = directory)
val sslConfiguration = sslConfiguration(directory)
return object : AutoClosableSSLConfiguration {
override val value = sslConfiguration
override fun close() {
keyStoreFile.close()
trustStoreFile.close()
}
}
}
private fun sslConfiguration(directory: Path) = SslOptions(directory, keyStore.password, trustStore.password)
}
interface AutoClosableSSLConfiguration : AutoCloseable {
val value: SslOptions
}
typealias KeyStoreEntry = Pair<String, UnsafeCertificate>
data class UnsafeKeyStore(private val delegate: KeyStore, val password: String) : Iterable<KeyStoreEntry> {
companion object {
private const val JKS_TYPE = "JKS"
fun create(type: String, password: String) = UnsafeKeyStore(newKeyStore(type, password), password)
fun createJKS(password: String) = create(JKS_TYPE, password)
}
operator fun plus(entry: KeyStoreEntry) = set(entry.first, entry.second)
override fun iterator(): Iterator<Pair<String, UnsafeCertificate>> = delegate.aliases().toList().map { alias -> alias to get(alias) }.iterator()
operator fun get(alias: String): UnsafeCertificate {
return when {
delegate.isKeyEntry(alias) -> delegate.getCertificateAndKeyPair(alias, password).unsafe()
else -> UnsafeCertificate(delegate.getX509Certificate(alias), null)
}
}
operator fun set(alias: String, certificate: UnsafeCertificate) {
delegate.setCertificateEntry(alias, certificate.value)
delegate.setKeyEntry(alias, certificate.privateKey, password.toCharArray(), arrayOf(certificate.value))
}
fun save(path: Path) = delegate.save(path, password)
fun toTemporaryFile(fileName: String, fileExtension: String? = delegate.type.toLowerCase(), directory: Path): TemporaryFile {
return TemporaryFile("$fileName.$fileExtension", directory).also { save(it.path) }
}
}
class TemporaryFile(fileName: String, val directory: Path) : AutoCloseable {
private val file = (directory / fileName).toFile()
init {
file.createNewFile()
file.deleteOnExit()
}
val path: Path = file.toPath().toAbsolutePath()
override fun close() = FileUtils.forceDelete(file)
}
data class UnsafeCertificate(val value: X509Certificate, val privateKey: PrivateKey?) {
val keyPair = KeyPair(value.publicKey, privateKey)
val principal: X500Principal get() = value.subjectX500Principal
val issuer: X500Principal get() = value.issuerX500Principal
fun createSigned(subject: X500Principal, keyType: String, signatureAlgorithm: String, keySize: Int, certificatesValidityWindow: CertificateValidityWindow): UnsafeCertificate {
val keyGen = keyGen(keyType, signatureAlgorithm, keySize)
return UnsafeCertificate(X509Utilities.createCertificate(
certificateType = CertificateType.TLS,
issuer = value.subjectX500Principal,
issuerKeyPair = keyPair,
validityWindow = certificatesValidityWindow.datePair,
subject = subject,
subjectPublicKey = keyGen.publicKey
), keyGen.privateKey)
}
fun createSigned(name: CordaX500Name, keyType: String, signatureAlgorithm: String, keySize: Int, certificatesValidityWindow: CertificateValidityWindow) = createSigned(name.x500Principal, keyType, signatureAlgorithm, keySize, certificatesValidityWindow)
}
data class CertificateValidityWindow(val from: Instant, val to: Instant) {
constructor(from: Instant, duration: Duration) : this(from, from.plus(duration))
val duration = Duration.between(from, to)!!
val datePair = Date.from(from) to Date.from(to)
}
private fun createSelfSigned(name: X500Name, keyType: String, signatureAlgorithm: String, keySize: Int, certificatesValidityWindow: CertificateValidityWindow): UnsafeCertificate {
val keyGen = keyGen(keyType, signatureAlgorithm, keySize)
return UnsafeCertificate(keyGen.getSelfCertificate(name, certificatesValidityWindow.duration.toMillis()), keyGen.privateKey)
}
private fun CordaX500Name.asX500Name(): X500Name = X500Name.asX500Name(x500Principal)
private fun CertificateAndKeyPair.unsafe() = UnsafeCertificate(certificate, keyPair.private)
private fun keyGen(keyType: String, signatureAlgorithm: String, keySize: Int): CertAndKeyGen {
val keyGen = CertAndKeyGen(keyType, signatureAlgorithm)
keyGen.generate(keySize)
return keyGen
}
private fun newKeyStore(type: String, password: String): KeyStore {
val keyStore = KeyStore.getInstance(type)
// Loading creates the store, can't do anything with it until it's loaded
keyStore.load(null, password.toCharArray())
return keyStore
}
fun withKeyStores(server: KeyStores, client: KeyStores, action: (brokerSslOptions: SslOptions, clientSslOptions: SslOptions) -> Unit) {
val serverDir = Files.createTempDirectory(null)
FileUtils.forceDeleteOnExit(serverDir.toFile())
val clientDir = Files.createTempDirectory(null)
FileUtils.forceDeleteOnExit(clientDir.toFile())
server.save(serverDir).use { serverSslConfiguration ->
client.save(clientDir).use { clientSslConfiguration ->
action(serverSslConfiguration.value, clientSslConfiguration.value)
}
}
FileUtils.deleteQuietly(clientDir.toFile())
FileUtils.deleteQuietly(serverDir.toFile())
}
fun withCertificates(factoryDefaults: UnsafeCertificatesFactory.Defaults = UnsafeCertificatesFactory.defaults(), action: (server: KeyStores, client: KeyStores, createSelfSigned: (name: CordaX500Name) -> UnsafeCertificate, createSignedBy: (name: CordaX500Name, issuer: UnsafeCertificate) -> UnsafeCertificate) -> Unit) {
val factory = UnsafeCertificatesFactory(factoryDefaults)
val server = factory.newKeyStores("serverKeyStorePass", "serverTrustKeyStorePass")
val client = factory.newKeyStores("clientKeyStorePass", "clientTrustKeyStorePass")
action(server, client, factory::createSelfSigned, factory::createSignedBy)
}