diff --git a/testing/node-driver/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt b/testing/node-driver/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt index dd8a94dc4b..7e9ee5dcf7 100644 --- a/testing/node-driver/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt +++ b/testing/node-driver/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt @@ -3,6 +3,8 @@ package net.corda.testing.driver import net.corda.core.concurrent.CordaFuture import net.corda.core.identity.CordaX500Name import net.corda.core.internal.CertRole +import net.corda.core.internal.concurrent.fork +import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.transpose import net.corda.core.internal.div import net.corda.core.internal.list @@ -11,7 +13,10 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.getOrThrow import net.corda.node.internal.NodeStartup import net.corda.testing.common.internal.ProjectStructure.projectRootDir -import net.corda.testing.core.* +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_BANK_A_NAME +import net.corda.testing.core.DUMMY_BANK_B_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME import net.corda.testing.driver.internal.RandomFree import net.corda.testing.http.HttpApi import net.corda.testing.node.NotarySpec @@ -21,9 +26,13 @@ import net.corda.testing.node.internal.internalDriver import org.assertj.core.api.Assertions.* import org.json.simple.JSONObject import org.junit.Test +import java.util.* +import java.util.concurrent.CountDownLatch import java.util.concurrent.Executors +import java.util.concurrent.ForkJoinPool import java.util.concurrent.ScheduledExecutorService import kotlin.streams.toList +import kotlin.test.assertEquals class DriverTests { private companion object { @@ -166,5 +175,31 @@ class DriverTests { } } + + @Test + fun `driver waits for in-process nodes to finish`() { + fun NodeHandle.stopQuietly() = try { + stop() + } catch (t: Throwable) { + t.printStackTrace() + } + + val handlesFuture = openFuture>() + val driverExit = CountDownLatch(1) + val testFuture = ForkJoinPool.commonPool().fork { + val handles = LinkedList(handlesFuture.getOrThrow()) + val last = handles.removeLast() + handles.forEach { it.stopQuietly() } + assertEquals(1, driverExit.count) + last.stopQuietly() + } + driver(DriverParameters(startNodesInProcess = true, waitForAllNodesToFinish = true)) { + val nodeA = newNode(DUMMY_BANK_A_NAME)().getOrThrow() + handlesFuture.set(listOf(nodeA) + notaryHandles.map { it.nodeHandles.getOrThrow() }.flatten()) + } + driverExit.countDown() + testFuture.getOrThrow() + } + private fun DriverDSL.newNode(name: CordaX500Name) = { startNode(NodeParameters(providedName = name)) } } \ No newline at end of file diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt index 80a282fd0a..c561805c88 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/DriverDSLImpl.kt @@ -72,6 +72,7 @@ import java.util.* import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit +import kotlin.collections.ArrayList import kotlin.collections.HashMap import kotlin.concurrent.thread import net.corda.nodeapi.internal.config.User as InternalUser @@ -107,8 +108,13 @@ class DriverDSLImpl( private lateinit var _notaries: CordaFuture> override val notaryHandles: List get() = _notaries.getOrThrow() + interface Waitable { + @Throws(InterruptedException::class) + fun waitFor(): Unit + } + class State { - val processes = ArrayList() + val processes = ArrayList() } private val state = ThreadBox(State()) @@ -617,20 +623,32 @@ class DriverDSLImpl( } } ) - return nodeAndThreadFuture.flatMap { (node, thread) -> + val nodeFuture: CordaFuture = nodeAndThreadFuture.flatMap { (node, thread) -> establishRpc(config, openFuture()).flatMap { rpc -> visibilityHandle.listen(rpc).map { InProcessImpl(rpc.nodeInfo(), rpc, config.corda, webAddress, useHTTPS, thread, onNodeExit, node) } } } + state.locked { + processes += object : Waitable { + override fun waitFor() { + nodeAndThreadFuture.getOrThrow().second.join() + } + } + } + return nodeFuture } else { val debugPort = if (isDebug) debugPortAllocation.nextPort() else null val monitorPort = if (jmxPolicy.startJmxHttpServer) jmxPolicy.jmxHttpServerPortAllocation?.nextPort() else null val process = startOutOfProcessNode(config, quasarJarPath, debugPort, jolokiaJarPath, monitorPort, systemProperties, cordappPackages, maximumHeapSize) if (waitForAllNodesToFinish) { state.locked { - processes += process + processes += object : Waitable { + override fun waitFor() { + process.waitFor() + } + } } } else { shutdownManager.registerProcessShutdown(process) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt index 6e5f7c02af..433cff4c00 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/ShutdownManager.kt @@ -3,7 +3,10 @@ package net.corda.testing.node.internal import net.corda.core.concurrent.CordaFuture import net.corda.core.internal.ThreadBox import net.corda.core.internal.concurrent.doneFuture -import net.corda.core.utilities.* +import net.corda.core.utilities.Try +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds import java.util.concurrent.ExecutorService import java.util.concurrent.Executors import java.util.concurrent.TimeoutException @@ -12,6 +15,7 @@ import java.util.concurrent.atomic.AtomicInteger class ShutdownManager(private val executorService: ExecutorService) { private class State { val registeredShutdowns = ArrayList Unit>>() + var isShuttingDown = false var isShutdown = false } @@ -32,6 +36,7 @@ class ShutdownManager(private val executorService: ExecutorService) { } fun shutdown() { + state.locked { isShuttingDown = true } val shutdownActionFutures = state.locked { if (isShutdown) { emptyList Unit>>() @@ -101,4 +106,8 @@ class ShutdownManager(private val executorService: ExecutorService) { } } } + + fun isShuttingDown(): Boolean { + return state.locked { isShuttingDown } + } } \ No newline at end of file