diff --git a/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt b/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt new file mode 100644 index 0000000000..a7c4d6afc9 --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt @@ -0,0 +1,364 @@ +package com.r3corda.node.driver + +import com.google.common.net.HostAndPort +import com.r3corda.core.crypto.Party +import com.r3corda.core.crypto.generateKeyPair +import com.r3corda.core.messaging.MessagingService +import com.r3corda.core.node.NodeInfo +import com.r3corda.core.node.services.NetworkMapCache +import com.r3corda.core.node.services.ServiceType +import com.r3corda.node.services.messaging.ArtemisMessagingClient +import com.r3corda.node.services.config.NodeConfigurationFromConfig +import com.r3corda.node.services.config.copy +import com.r3corda.node.services.network.InMemoryNetworkMapCache +import com.r3corda.node.services.network.NetworkMapService +import com.typesafe.config.ConfigFactory +import com.typesafe.config.ConfigParseOptions +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import java.net.ServerSocket +import java.net.Socket +import java.net.SocketException +import java.nio.file.Paths +import java.text.SimpleDateFormat +import java.util.* +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import java.util.concurrent.TimeoutException + +/** + * This file defines a small "Driver" DSL for starting up nodes. + * + * The process the driver is run in behaves as an Artemis client and starts up other processes. Namely it first + * bootstraps a network map service to allow the specified nodes to connect to, then starts up the actual nodes. + * + * TODO The driver actually starts up as an Artemis server now that may route traffic. Fix this once the client MessagingService is done. + * TODO The nodes are started up sequentially which is quite slow. Either speed up node startup or make startup parallel somehow. + * TODO The driver now polls the network map cache for info about newly started up nodes, this could be done asynchronously(?). + * TODO The network map service bootstrap is hacky (needs to fake the service's public key in order to retrieve the true one), needs some thought. + */ + +private val log: Logger = LoggerFactory.getLogger(DriverDSL::class.java) + +/** + * This is the interface that's exposed to + */ +interface DriverDSLExposedInterface { + fun startNode(providedName: String? = null, advertisedServices: Set = setOf()): NodeInfo + fun waitForAllNodesToFinish() + val messagingService: MessagingService + val networkMapCache: NetworkMapCache +} + +interface DriverDSLInternalInterface : DriverDSLExposedInterface { + fun start() + fun shutdown() +} + +sealed class PortAllocation { + abstract fun nextPort(): Int + fun nextHostAndPort(): HostAndPort = HostAndPort.fromParts("localhost", nextPort()) + + class Incremental(private var portCounter: Int) : PortAllocation() { + override fun nextPort() = portCounter++ + } + class RandomFree(): PortAllocation() { + override fun nextPort() = ServerSocket(0).use { it.localPort } + } +} + +/** + * [driver] allows one to start up nodes like this: + * driver { + * val noService = startNode("NoService") + * val notary = startNode("Notary") + * + * (...) + * } + * + * The driver implicitly bootstraps a [NetworkMapService] that may be accessed through a local cache [DriverDSL.networkMapCache] + * The driver is an artemis node itself, the messaging service may be accessed by [DriverDSL.messagingService] + * + * @param baseDirectory The base directory node directories go into, defaults to "build//". The node + * directories themselves are "//", where legalName defaults to "-" + * and may be specified in [DriverDSL.startNode]. + * @param nodeConfigurationPath The path to the node's .conf, defaults to "reference.conf". + * @param quasarJarPath The path to quasar.jar, relative to cwd. Defaults to "lib/quasar.jar". TODO remove this once we can bundle quasar properly. + * @param portAllocation The port allocation strategy to use for the messaging and the web server addresses. Defaults to incremental. + * @param debugPortAllocation The port allocation strategy to use for jvm debugging. Defaults to incremental. + * @param dsl The dsl itself + * @return The value returned in the [dsl] closure + */ +fun driver( + baseDirectory: String = "build/${getTimestampAsDirectoryName()}", + nodeConfigurationPath: String = "reference.conf", + quasarJarPath: String = "lib/quasar.jar", + portAllocation: PortAllocation = PortAllocation.Incremental(10000), + debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), + dsl: DriverDSLExposedInterface.() -> A +) = genericDriver( + driverDsl = DriverDSL( + portAllocation = portAllocation, + debugPortAllocation = debugPortAllocation, + baseDirectory = baseDirectory, + nodeConfigurationPath = nodeConfigurationPath, + quasarJarPath = quasarJarPath + ), + coerce = { it }, + dsl = dsl +) + + +/** + * This is a helper method to allow extending of the DSL, along the lines of + * interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface + * interface SomeOtherInternalDSLInterface : DriverDSLInternalInterface, SomeOtherExposedDSLInterface + * class SomeOtherDSL(val driverDSL : DriverDSL) : DriverDSLInternalInterface by driverDSL, SomeOtherInternalDSLInterface + * + * @param coerce We need this explicit coercion witness because we can't put an extra DI : D bound in a `where` clause + */ +fun genericDriver( + driverDsl: D, + coerce: (D) -> DI, + dsl: DI.() -> A +): A { + var shutdownHook: Thread? = null + try { + driverDsl.start() + val returnValue = dsl(coerce(driverDsl)) + shutdownHook = Thread({ + driverDsl.shutdown() + }) + Runtime.getRuntime().addShutdownHook(shutdownHook) + return returnValue + } finally { + driverDsl.shutdown() + if (shutdownHook != null) { + Runtime.getRuntime().removeShutdownHook(shutdownHook) + } + } +} + +private fun getTimestampAsDirectoryName(): String { + val tz = TimeZone.getTimeZone("UTC") + val df = SimpleDateFormat("yyyyMMddHHmmss") + df.timeZone = tz + return df.format(Date()) +} + +fun addressMustBeBound(hostAndPort: HostAndPort) { + poll { + try { + Socket(hostAndPort.hostText, hostAndPort.port).close() + Unit + } catch (_exception: SocketException) { + null + } + } +} + +fun addressMustNotBeBound(hostAndPort: HostAndPort) { + poll { + try { + Socket(hostAndPort.hostText, hostAndPort.port).close() + null + } catch (_exception: SocketException) { + Unit + } + } +} + +fun poll(f: () -> A?): A { + var counter = 0 + var result = f() + while (result == null && counter < 120) { + counter++ + Thread.sleep(500) + result = f() + } + if (result == null) { + throw Exception("Poll timed out") + } + return result +} + +class DriverDSL( + val portAllocation: PortAllocation, + val debugPortAllocation: PortAllocation, + val baseDirectory: String, + val nodeConfigurationPath: String, + val quasarJarPath: String +) : DriverDSLInternalInterface { + + override val networkMapCache = InMemoryNetworkMapCache() + private val networkMapName = "NetworkMapService" + private val networkMapAddress = portAllocation.nextHostAndPort() + private var networkMapNodeInfo: NodeInfo? = null + private val registeredProcesses = LinkedList() + + val nodeConfiguration = + NodeConfigurationFromConfig( + ConfigFactory.parseResources( + nodeConfigurationPath, + ConfigParseOptions.defaults().setAllowMissing(false) + ) + ).copy( + myLegalName = "driver-artemis" + ) + + override val messagingService = ArtemisMessagingClient( + Paths.get(baseDirectory, "driver-artemis"), + nodeConfiguration, + serverHostPort = networkMapAddress, + myHostPort = portAllocation.nextHostAndPort() + ) + var messagingServiceStarted = false + + fun registerProcess(process: Process) = registeredProcesses.push(process) + + override fun waitForAllNodesToFinish() { + registeredProcesses.forEach { + it.waitFor() + } + } + + override fun shutdown() { + registeredProcesses.forEach { + it.destroy() + } + /** Wait 5 seconds, then [Process.destroyForcibly] */ + val finishedFuture = Executors.newSingleThreadExecutor().submit { + waitForAllNodesToFinish() + } + try { + finishedFuture.get(5, TimeUnit.SECONDS) + } catch (exception: TimeoutException) { + finishedFuture.cancel(true) + registeredProcesses.forEach { + it.destroyForcibly() + } + } + if (messagingServiceStarted){ + messagingService.stop() + } + + // Check that we shut down properly + addressMustNotBeBound(messagingService.myHostPort) + addressMustNotBeBound(networkMapAddress) + } + + /** + * Starts a [Node] in a separate process. + * + * @param providedName Optional name of the node, which will be its legal name in [Party]. Defaults to something + * random. Note that this must be unique as the driver uses it as a primary key! + * @param advertisedServices The set of services to be advertised by the node. Defaults to empty set. + * @return The [NodeInfo] of the started up node retrieved from the network map service. + */ + override fun startNode(providedName: String?, advertisedServices: Set): NodeInfo { + val messagingAddress = portAllocation.nextHostAndPort() + val apiAddress = portAllocation.nextHostAndPort() + val debugPort = debugPortAllocation.nextPort() + val name = providedName ?: "${pickA(name)}-${messagingAddress.port}" + + val driverCliParams = NodeRunner.CliParams( + services = advertisedServices, + networkMapName = networkMapNodeInfo!!.identity.name, + networkMapPublicKey = networkMapNodeInfo!!.identity.owningKey, + networkMapAddress = networkMapAddress, + messagingAddress = messagingAddress, + apiAddress = apiAddress, + baseDirectory = baseDirectory, + nodeConfigurationPath = nodeConfigurationPath, + legalName = name + ) + registerProcess(startNode(driverCliParams, quasarJarPath, debugPort)) + + return poll { + networkMapCache.partyNodes.forEach { + if (it.identity.name == name) { + return@poll it + } + } + null + } + } + + override fun start() { + startNetworkMapService() + messagingService.configureWithDevSSLCertificate() + messagingService.start() + messagingServiceStarted = true + // We fake the network map's NodeInfo with a random public key in order to retrieve the correct NodeInfo from + // the network map service itself + val fakeNodeInfo = NodeInfo( + address = ArtemisMessagingClient.makeRecipient(networkMapAddress), + identity = Party( + name = networkMapName, + owningKey = generateKeyPair().public + ), + advertisedServices = setOf(NetworkMapService.Type) + ) + networkMapCache.addMapService(messagingService, fakeNodeInfo, true) + networkMapNodeInfo = poll { + networkMapCache.partyNodes.forEach { + if (it.identity.name == networkMapName) { + return@poll it + } + } + null + } + } + + private fun startNetworkMapService() { + val apiAddress = portAllocation.nextHostAndPort() + val debugPort = debugPortAllocation.nextPort() + val driverCliParams = NodeRunner.CliParams( + services = setOf(NetworkMapService.Type), + networkMapName = null, + networkMapPublicKey = null, + networkMapAddress = null, + messagingAddress = networkMapAddress, + apiAddress = apiAddress, + baseDirectory = baseDirectory, + nodeConfigurationPath = nodeConfigurationPath, + legalName = networkMapName + ) + log.info("Starting network-map-service") + registerProcess(startNode(driverCliParams, quasarJarPath, debugPort)) + } + + companion object { + + val name = arrayOf( + "Alice", + "Bob", + "EvilBank", + "NotSoEvilBank" + ) + fun pickA(array: Array): A = array[Math.abs(Random().nextInt()) % array.size] + + private fun startNode(cliParams: NodeRunner.CliParams, quasarJarPath: String, debugPort: Int): Process { + val className = NodeRunner::class.java.canonicalName + val separator = System.getProperty("file.separator") + val classpath = System.getProperty("java.class.path") + val path = System.getProperty("java.home") + separator + "bin" + separator + "java" + val javaArgs = listOf(path) + + listOf("-Dname=${cliParams.legalName}", "-javaagent:$quasarJarPath", + "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$debugPort", + "-cp", classpath, className) + + cliParams.toCliArguments() + val builder = ProcessBuilder(javaArgs) + builder.redirectError(Paths.get("error.$className.log").toFile()) + builder.inheritIO() + val process = builder.start() + addressMustBeBound(cliParams.messagingAddress) + // TODO There is a race condition here. Even though the messaging address is bound it may be the case that + // the handlers for the advertised services are not yet registered. A hacky workaround is that we wait for + // the web api address to be bound as well, as that starts after the services. Needs rethinking. + addressMustBeBound(cliParams.apiAddress) + + return process + } + } +} diff --git a/node/src/main/kotlin/com/r3corda/node/driver/NodeRunner.kt b/node/src/main/kotlin/com/r3corda/node/driver/NodeRunner.kt new file mode 100644 index 0000000000..ba685619ba --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/driver/NodeRunner.kt @@ -0,0 +1,171 @@ +package com.r3corda.node.driver + +import com.google.common.net.HostAndPort +import com.r3corda.core.crypto.Party +import com.r3corda.core.crypto.parsePublicKeyBase58 +import com.r3corda.core.crypto.toBase58String +import com.r3corda.core.node.NodeInfo +import com.r3corda.core.node.services.ServiceType +import com.r3corda.node.internal.Node +import com.r3corda.node.services.messaging.ArtemisMessagingClient +import com.r3corda.node.services.config.NodeConfigurationFromConfig +import com.r3corda.node.services.config.copy +import com.r3corda.node.services.network.NetworkMapService +import com.typesafe.config.ConfigFactory +import com.typesafe.config.ConfigParseOptions +import joptsimple.ArgumentAcceptingOptionSpec +import joptsimple.OptionParser +import joptsimple.OptionSet +import org.slf4j.Logger +import org.slf4j.LoggerFactory +import java.nio.file.Path +import java.nio.file.Paths +import java.security.PublicKey +import java.util.* + +private val log: Logger = LoggerFactory.getLogger(NodeRunner::class.java) + +class NodeRunner { + companion object { + @JvmStatic fun main(arguments: Array) { + val cliParams = CliParams.parse(CliParams.parser.parse(*arguments)) + + val nodeDirectory = Paths.get(cliParams.baseDirectory, cliParams.legalName) + createNodeRunDirectory(nodeDirectory) + + with(cliParams) { + + val networkMapNodeInfo = + if (networkMapName != null && networkMapPublicKey != null && networkMapAddress != null) { + NodeInfo( + address = ArtemisMessagingClient.makeRecipient(networkMapAddress), + identity = Party( + name = networkMapName, + owningKey = networkMapPublicKey + ), + advertisedServices = setOf(NetworkMapService.Type) + ) + } else { + null + } + val nodeConfiguration = + NodeConfigurationFromConfig( + ConfigFactory.parseResources( + nodeConfigurationPath, + ConfigParseOptions.defaults().setAllowMissing(false) + ) + ).copy( + myLegalName = legalName + ) + + val node = Node( + dir = nodeDirectory, + p2pAddr = messagingAddress, + webServerAddr = apiAddress, + configuration = nodeConfiguration, + networkMapAddress = networkMapNodeInfo, + advertisedServices = services.toSet() + ) + + log.info("Starting $legalName with services $services on addresses $messagingAddress and $apiAddress") + node.start() + } + } + } + + class CliParams ( + val services: Set, + val networkMapName: String?, + val networkMapPublicKey: PublicKey?, + val networkMapAddress: HostAndPort?, + val messagingAddress: HostAndPort, + val apiAddress: HostAndPort, + val baseDirectory: String, + val nodeConfigurationPath: String, + val legalName: String + ) { + + companion object { + val parser = OptionParser() + val services = + parser.accepts("services").withRequiredArg().ofType(String::class.java) + val networkMapName = + parser.accepts("network-map-name").withOptionalArg().ofType(String::class.java) + val networkMapPublicKey = + parser.accepts("network-map-public-key").withOptionalArg().ofType(String::class.java) + val networkMapAddress = + parser.accepts("network-map-address").withOptionalArg().ofType(String::class.java) + val messagingAddress = + parser.accepts("messaging-address").withRequiredArg().ofType(String::class.java) + val apiAddress = + parser.accepts("api-address").withRequiredArg().ofType(String::class.java) + val baseDirectory = + parser.accepts("base-directory").withRequiredArg().ofType(String::class.java) + val nodeConfigurationPath = + parser.accepts("node-configuration-path").withRequiredArg().ofType(String::class.java) + val legalName = + parser.accepts("legal-name").withRequiredArg().ofType(String::class.java) + + private fun requiredArgument(optionSet: OptionSet, spec: ArgumentAcceptingOptionSpec) = + optionSet.valueOf(spec) ?: throw IllegalArgumentException("Must provide $spec") + + fun parse(optionSet: OptionSet): CliParams { + val services = optionSet.valuesOf(services) + val networkMapName = optionSet.valueOf(networkMapName) + val networkMapPublicKey = optionSet.valueOf(networkMapPublicKey)?.run { parsePublicKeyBase58(this) } + val networkMapAddress = optionSet.valueOf(networkMapAddress) + val messagingAddress = requiredArgument(optionSet, messagingAddress) + val apiAddress = requiredArgument(optionSet, apiAddress) + val baseDirectory = requiredArgument(optionSet, baseDirectory) + val nodeConfigurationPath = requiredArgument(optionSet, nodeConfigurationPath) + val legalName = requiredArgument(optionSet, legalName) + + return CliParams( + services = services.map { object : ServiceType(it) {} }.toSet(), + messagingAddress = HostAndPort.fromString(messagingAddress), + apiAddress = HostAndPort.fromString(apiAddress), + baseDirectory = baseDirectory, + networkMapName = networkMapName, + networkMapPublicKey = networkMapPublicKey, + networkMapAddress = networkMapAddress?.let { HostAndPort.fromString(it) }, + nodeConfigurationPath = nodeConfigurationPath, + legalName = legalName + ) + } + } + + fun toCliArguments(): List { + val cliArguments = LinkedList() + if (services.isNotEmpty()) { + cliArguments.add("--services") + cliArguments.addAll(services.map { it.toString() }) + } + if (networkMapName != null) { + cliArguments.add("--network-map-name") + cliArguments.add(networkMapName) + } + if (networkMapPublicKey != null) { + cliArguments.add("--network-map-public-key") + cliArguments.add(networkMapPublicKey.toBase58String()) + } + if (networkMapAddress != null) { + cliArguments.add("--network-map-address") + cliArguments.add(networkMapAddress.toString()) + } + cliArguments.add("--messaging-address") + cliArguments.add(messagingAddress.toString()) + cliArguments.add("--api-address") + cliArguments.add(apiAddress.toString()) + cliArguments.add("--base-directory") + cliArguments.add(baseDirectory.toString()) + cliArguments.add("--node-configuration-path") + cliArguments.add(nodeConfigurationPath) + cliArguments.add("--legal-name") + cliArguments.add(legalName) + return cliArguments + } + } +} + +fun createNodeRunDirectory(directory: Path) = directory.toFile().mkdirs() + diff --git a/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingClient.kt b/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingClient.kt index 2b2f5374e1..7b35d4bcdc 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingClient.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingClient.kt @@ -65,7 +65,8 @@ class ArtemisMessagingClient(directory: Path, private val mutex = ThreadBox(InnerState()) private val handlers = CopyOnWriteArrayList() - private lateinit var clientFactory: ClientSessionFactory + private var serverLocator: ServerLocator? = null + private var clientFactory: ClientSessionFactory? = null private var session: ClientSession? = null private var consumer: ClientConsumer? = null @@ -86,8 +87,11 @@ class ArtemisMessagingClient(directory: Path, private fun configureAndStartClient() { log.info("Connecting to server: $serverHostPort") // Connect to our server. - clientFactory = ActiveMQClient.createServerLocatorWithoutHA( - tcpTransport(ConnectionDirection.OUTBOUND, serverHostPort.hostText, serverHostPort.port)).createSessionFactory() + val serverLocator = ActiveMQClient.createServerLocatorWithoutHA( + tcpTransport(ConnectionDirection.OUTBOUND, serverHostPort.hostText, serverHostPort.port)) + this.serverLocator = serverLocator + val clientFactory = serverLocator.createSessionFactory() + this.clientFactory = clientFactory // Create a queue on which to receive messages and set up the handler. val session = clientFactory.createSession() @@ -168,6 +172,8 @@ class ArtemisMessagingClient(directory: Path, producers.clear() consumer?.close() session?.close() + clientFactory?.close() + serverLocator?.close() // We expect to be garbage collected shortly after being stopped, so we don't null anything explicitly here. running = false } diff --git a/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingComponent.kt b/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingComponent.kt index cb190f0af2..39231d4e09 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingComponent.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/messaging/ArtemisMessagingComponent.kt @@ -23,7 +23,7 @@ abstract class ArtemisMessagingComponent(val directory: Path, val config: NodeCo private val trustStorePath = directory.resolve("certificates").resolve("truststore.jks") // In future: can contain onion routing info, etc. - protected data class Address(val hostAndPort: HostAndPort) : SingleMessageRecipient + data class Address(val hostAndPort: HostAndPort) : SingleMessageRecipient protected enum class ConnectionDirection { INBOUND, OUTBOUND } diff --git a/src/main/resources/reference.conf b/node/src/main/resources/reference.conf similarity index 100% rename from src/main/resources/reference.conf rename to node/src/main/resources/reference.conf diff --git a/node/src/test/kotlin/com/r3corda/node/driver/DriverTests.kt b/node/src/test/kotlin/com/r3corda/node/driver/DriverTests.kt new file mode 100644 index 0000000000..0e1ebc6b19 --- /dev/null +++ b/node/src/test/kotlin/com/r3corda/node/driver/DriverTests.kt @@ -0,0 +1,67 @@ +package com.r3corda.node.driver + +import com.google.common.net.HostAndPort +import com.r3corda.core.node.NodeInfo +import com.r3corda.core.node.services.NetworkMapCache +import com.r3corda.node.services.api.RegulatorService +import com.r3corda.node.services.messaging.ArtemisMessagingComponent +import com.r3corda.node.services.transactions.NotaryService +import org.junit.Test + + +class DriverTests { + + companion object { + fun nodeMustBeUp(networkMapCache: NetworkMapCache, nodeInfo: NodeInfo, nodeName: String) { + val address = nodeInfo.address as ArtemisMessagingComponent.Address + // Check that the node is registered in the network map + poll { + networkMapCache.get().firstOrNull { + it.identity.name == nodeName + } + } + // Check that the port is bound + addressMustBeBound(address.hostAndPort) + } + + fun nodeMustBeDown(nodeInfo: NodeInfo) { + val address = nodeInfo.address as ArtemisMessagingComponent.Address + // Check that the port is bound + addressMustNotBeBound(address.hostAndPort) + } + } + + @Test + fun simpleNodeStartupShutdownWorks() { + val (notary, regulator) = driver(quasarJarPath = "../lib/quasar.jar") { + val notary = startNode("TestNotary", setOf(NotaryService.Type)) + val regulator = startNode("Regulator", setOf(RegulatorService.Type)) + + nodeMustBeUp(networkMapCache, notary, "TestNotary") + nodeMustBeUp(networkMapCache, regulator, "Regulator") + Pair(notary, regulator) + } + nodeMustBeDown(notary) + nodeMustBeDown(regulator) + } + + @Test + fun startingNodeWithNoServicesWorks() { + val noService = driver(quasarJarPath = "../lib/quasar.jar") { + val noService = startNode("NoService") + nodeMustBeUp(networkMapCache, noService, "NoService") + noService + } + nodeMustBeDown(noService) + } + + @Test + fun randomFreePortAllocationWorks() { + val nodeInfo = driver(quasarJarPath = "../lib/quasar.jar", portAllocation = PortAllocation.RandomFree()) { + val nodeInfo = startNode("NoService") + nodeMustBeUp(networkMapCache, nodeInfo, "NoService") + nodeInfo + } + nodeMustBeDown(nodeInfo) + } +}