diff --git a/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt b/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt index 248dcb0d1a..4a9a11853d 100644 --- a/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt +++ b/node/src/main/kotlin/com/r3corda/node/driver/Driver.kt @@ -7,9 +7,9 @@ import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.services.ServiceType import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.messaging.ArtemisMessagingClient -import com.r3corda.node.services.messaging.ArtemisMessagingComponent import com.r3corda.node.services.network.InMemoryNetworkMapCache import com.r3corda.node.services.network.NetworkMapService +import java.net.ServerSocket import java.net.Socket import java.net.SocketException import java.nio.file.Paths @@ -45,11 +45,33 @@ interface DriverDSLInterface { fun startNode(advertisedServices: Set, providedName: String? = null): NodeInfo } -fun driver(baseDirectory: String? = null, quasarPath: String? = null, dsl: DriverDSL.() -> A): Pair { +sealed class PortAllocation { + abstract fun nextPort(): Int + fun nextHostAndPort() = HostAndPort.fromParts("localhost", nextPort()) + + class Incremental(private var portCounter: Int) : PortAllocation() { + override fun nextPort() = portCounter++ + } + class RandomFree(): PortAllocation() { + override fun nextPort() = ServerSocket(0).use { it.localPort } + } +} + +/** + * TODO: remove quasarJarPath once we have a proper way of bundling quasar + */ +fun driver( + baseDirectory: String = "build/${getTimestampAsDirectoryName()}", + quasarJarPath: String = "lib/quasar.jar", + portAllocation: PortAllocation = PortAllocation.Incremental(10000), + debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), + dsl: DriverDSL.() -> A +): Pair { val driverDsl = DriverDSL( - portCounter = 10000, - baseDirectory = baseDirectory ?: "build/${getTimestampAsDirectoryName()}", - quasarPath = quasarPath ?: "lib/quasar.jar" + portAllocation = portAllocation, + debugPortAllocation = debugPortAllocation, + baseDirectory = baseDirectory, + quasarJarPath = quasarJarPath ) driverDsl.start() val returnValue = dsl(driverDsl) @@ -95,13 +117,16 @@ fun poll(f: () -> A?): A { return result } -class DriverDSL(private var portCounter: Int, val baseDirectory: String, val quasarPath: String) : DriverDSLInterface { - - fun nextLocalHostAndPort() = HostAndPort.fromParts("localhost", nextPort()) +class DriverDSL( + private val portAllocation: PortAllocation, + private val debugPortAllocation: PortAllocation, + val baseDirectory: String, + val quasarJarPath: String +) : DriverDSLInterface { val networkMapCache = InMemoryNetworkMapCache(null) private val networkMapName = "NetworkMapService" - private val networkMapAddress = nextLocalHostAndPort() + private val networkMapAddress = portAllocation.nextHostAndPort() private lateinit var networkMapNodeInfo: NodeInfo private val registeredProcesses = LinkedList() @@ -115,15 +140,9 @@ class DriverDSL(private var portCounter: Int, val baseDirectory: String, val qua override val trustStorePassword = "trustpass" }, serverHostPort = networkMapAddress, - myHostPort = nextLocalHostAndPort() + myHostPort = portAllocation.nextHostAndPort() ) - private fun nextPort(): Int { - val nextPort = portCounter - portCounter++ - return nextPort - } - fun registerProcess(process: Process) = registeredProcesses.push(process) internal fun waitForAllNodesToFinish() { @@ -151,7 +170,9 @@ class DriverDSL(private var portCounter: Int, val baseDirectory: String, val qua } override fun startNode(advertisedServices: Set, providedName: String?): NodeInfo { - val messagingAddress = nextLocalHostAndPort() + val messagingAddress = portAllocation.nextHostAndPort() + val apiAddress = portAllocation.nextHostAndPort() + val debugPort = debugPortAllocation.nextPort() val name = providedName ?: "${pickA(name)}-${messagingAddress.port}" val nearestCity = pickA(city) @@ -161,12 +182,12 @@ class DriverDSL(private var portCounter: Int, val baseDirectory: String, val qua networkMapPublicKey = networkMapNodeInfo.identity.owningKey, networkMapAddress = networkMapAddress, messagingAddress = messagingAddress, - apiAddress = nextLocalHostAndPort(), + apiAddress = apiAddress, baseDirectory = baseDirectory, nearestCity = nearestCity, legalName = name ) - registerProcess(startNode(driverCliParams, quasarPath)) + registerProcess(startNode(driverCliParams, quasarJarPath, debugPort)) return poll { networkMapCache.partyNodes.forEach { @@ -204,19 +225,21 @@ class DriverDSL(private var portCounter: Int, val baseDirectory: String, val qua } private fun startNetworkMapService() { + val apiAddress = portAllocation.nextHostAndPort() + val debugPort = debugPortAllocation.nextPort() val driverCliParams = NodeRunner.CliParams( services = setOf(NetworkMapService.Type), networkMapName = null, networkMapPublicKey = null, networkMapAddress = null, messagingAddress = networkMapAddress, - apiAddress = nextLocalHostAndPort(), + apiAddress = apiAddress, baseDirectory = baseDirectory, nearestCity = pickA(city), legalName = networkMapName ) println("Starting network-map-service") - registerProcess(startNode(driverCliParams, quasarPath)) + registerProcess(startNode(driverCliParams, quasarJarPath, debugPort)) } companion object { @@ -235,13 +258,15 @@ class DriverDSL(private var portCounter: Int, val baseDirectory: String, val qua ) fun pickA(array: Array): A = array[Math.abs(Random().nextInt()) % array.size] - private fun startNode(cliParams: NodeRunner.CliParams, quasarPath: String): Process { + private fun startNode(cliParams: NodeRunner.CliParams, quasarJarPath: String, debugPort: Int): Process { val className = NodeRunner::class.java.canonicalName val separator = System.getProperty("file.separator") val classpath = System.getProperty("java.class.path") val path = System.getProperty("java.home") + separator + "bin" + separator + "java" val javaArgs = listOf(path) + - listOf("-Dname=${cliParams.legalName}", "-javaagent:$quasarPath", "-cp", classpath, className) + + listOf("-Dname=${cliParams.legalName}", "-javaagent:$quasarJarPath", + "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$debugPort", + "-cp", classpath, className) + cliParams.toCliArguments() val builder = ProcessBuilder(javaArgs) builder.redirectError(Paths.get("error.$className.log").toFile())