diff --git a/node/src/main/kotlin/net/corda/node/driver/Driver.kt b/node/src/main/kotlin/net/corda/node/driver/Driver.kt index 58ec2c0cb7..cc61d7c7b6 100644 --- a/node/src/main/kotlin/net/corda/node/driver/Driver.kt +++ b/node/src/main/kotlin/net/corda/node/driver/Driver.kt @@ -1,15 +1,15 @@ @file:JvmName("Driver") - package net.corda.node.driver import com.google.common.net.HostAndPort -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture +import com.google.common.util.concurrent.* import com.typesafe.config.Config import com.typesafe.config.ConfigRenderOptions -import net.corda.core.* +import net.corda.core.ThreadBox import net.corda.core.crypto.Party +import net.corda.core.div +import net.corda.core.flatMap +import net.corda.core.map import net.corda.core.messaging.CordaRPCOps import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceInfo @@ -92,6 +92,12 @@ interface DriverDSLExposedInterface { */ fun startWebserver(handle: NodeHandle): ListenableFuture + /** + * Starts a network map service node. Note that only a single one should ever be running, so you will probably want + * to set automaticallyStartNetworkMap to false in your [driver] call. + */ + fun startNetworkMapService() + fun waitForAllNodesToFinish() } @@ -161,6 +167,7 @@ fun driver( debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), systemProperties: Map = emptyMap(), useTestClock: Boolean = false, + automaticallyStartNetworkMap: Boolean = true, dsl: DriverDSLExposedInterface.() -> A ) = genericDriver( driverDsl = DriverDSL( @@ -169,6 +176,7 @@ fun driver( systemProperties = systemProperties, driverDirectory = driverDirectory.toAbsolutePath(), useTestClock = useTestClock, + automaticallyStartNetworkMap = automaticallyStartNetworkMap, isDebug = isDebug ), coerce = { it }, @@ -235,7 +243,7 @@ fun addressMustNotBeBound(executorService: ScheduledExecutorService, hostAndPort } } -private fun poll( +fun poll( executorService: ScheduledExecutorService, pollName: String, pollIntervalMs: Long = 500, @@ -267,21 +275,74 @@ private fun poll( return resultFuture } -open class DriverDSL( +class ShutdownManager(private val executorService: ExecutorService) { + private class State { + val registeredShutdowns = ArrayList Unit>>() + var isShutdown = false + } + private val state = ThreadBox(State()) + + fun shutdown() { + val shutdownFutures = state.locked { + require(!isShutdown) + isShutdown = true + registeredShutdowns + } + val shutdownsFuture = Futures.allAsList(shutdownFutures) + val shutdowns = try { + shutdownsFuture.get(1, SECONDS) + } catch (exception: TimeoutException) { + /** Could not get all of them, collect what we have */ + shutdownFutures.filter { it.isDone }.map { it.get() } + } + shutdowns.reversed().forEach{ it() } + } + + fun registerShutdown(shutdown: ListenableFuture<() -> Unit>) { + state.locked { + require(!isShutdown) + registeredShutdowns.add(shutdown) + } + } + + fun registerProcessShutdown(processFuture: ListenableFuture) { + val processShutdown = processFuture.map { process -> + { + process.destroy() + /** Wait 5 seconds, then [Process.destroyForcibly] */ + val finishedFuture = executorService.submit { + process.waitFor() + } + try { + finishedFuture.get(5, SECONDS) + } catch (exception: TimeoutException) { + finishedFuture.cancel(true) + process.destroyForcibly() + } + Unit + } + } + registerShutdown(processShutdown) + } +} + +class DriverDSL( val portAllocation: PortAllocation, val debugPortAllocation: PortAllocation, val systemProperties: Map, val driverDirectory: Path, val useTestClock: Boolean, - val isDebug: Boolean + val isDebug: Boolean, + val automaticallyStartNetworkMap: Boolean ) : DriverDSLInternalInterface { - private val executorService: ScheduledExecutorService = Executors.newScheduledThreadPool(2) private val networkMapLegalName = "NetworkMapService" private val networkMapAddress = portAllocation.nextHostAndPort() + val executorService: ListeningScheduledExecutorService = MoreExecutors.listeningDecorator(Executors.newScheduledThreadPool(2)) + val shutdownManager = ShutdownManager(executorService) class State { - val registeredProcesses = LinkedList>() val clients = LinkedList() + val processes = ArrayList>() } private val state = ThreadBox(State()) @@ -295,37 +356,24 @@ open class DriverDSL( Paths.get(quasarFileUrl.toURI()).toString() } - fun registerProcess(process: ListenableFuture) = state.locked { registeredProcesses.push(process) } - - override fun waitForAllNodesToFinish() { + fun registerProcess(process: ListenableFuture) { + shutdownManager.registerProcessShutdown(process) state.locked { - registeredProcesses.forEach { - it.getOrThrow().waitFor() - } + processes.add(process) + } + } + + override fun waitForAllNodesToFinish() = state.locked { + Futures.allAsList(processes).get().forEach { + it.waitFor() } } override fun shutdown() { state.locked { clients.forEach(NodeMessagingClient::stop) - registeredProcesses.forEach { - it.get().destroy() - } - } - /** Wait 5 seconds, then [Process.destroyForcibly] */ - val finishedFuture = executorService.submit { - waitForAllNodesToFinish() - } - try { - finishedFuture.get(5, SECONDS) - } catch (exception: TimeoutException) { - finishedFuture.cancel(true) - state.locked { - registeredProcesses.forEach { - it.get().destroyForcibly() - } - } } + shutdownManager.shutdown() // Check that we shut down properly addressMustNotBeBound(executorService, networkMapAddress).get() @@ -458,10 +506,12 @@ open class DriverDSL( } override fun start() { - startNetworkMapService() + if (automaticallyStartNetworkMap) { + startNetworkMapService() + } } - private fun startNetworkMapService(): ListenableFuture { + override fun startNetworkMapService() { val debugPort = if (isDebug) debugPortAllocation.nextPort() else null val apiAddress = portAllocation.nextHostAndPort().toString() val baseDirectory = driverDirectory / networkMapLegalName @@ -481,7 +531,6 @@ open class DriverDSL( log.info("Starting network-map-service") val startNode = startNode(executorService, FullNodeConfiguration(baseDirectory, config), quasarJarPath, debugPort, systemProperties) registerProcess(startNode) - return startNode } companion object { @@ -571,5 +620,5 @@ open class DriverDSL( fun writeConfig(path: Path, filename: String, config: Config) { path.toFile().mkdirs() - File("$path/$filename").writeText(config.root().render(ConfigRenderOptions.concise())) + File("$path/$filename").writeText(config.root().render(ConfigRenderOptions.defaults())) }