From 53bbb57345542708dd34b5d5fc62ab34e685a15f Mon Sep 17 00:00:00 2001 From: exfalso <0slemi0@gmail.com> Date: Thu, 15 Dec 2016 11:35:24 +0000 Subject: [PATCH] Add ReceivedSessionMessage, DriverBasedTest re #57 --- .../net/corda/client/CordaRPCClientTest.kt | 27 ++----- .../net/corda/client/NodeMonitorModelTest.kt | 51 +++++-------- docs/source/messaging.rst | 5 +- .../RaftValidatingNotaryServiceTests.kt | 73 ++++++++----------- .../kotlin/net/corda/node/driver/Driver.kt | 34 --------- .../net/corda/node/driver/DriverBasedTest.kt | 39 ++++++++++ .../messaging/ArtemisMessagingServer.kt | 8 +- .../statemachine/FlowStateMachineImpl.kt | 16 ++-- .../statemachine/StateMachineManager.kt | 7 +- 9 files changed, 113 insertions(+), 147 deletions(-) create mode 100644 node/src/main/kotlin/net/corda/node/driver/DriverBasedTest.kt diff --git a/client/src/integration-test/kotlin/net/corda/client/CordaRPCClientTest.kt b/client/src/integration-test/kotlin/net/corda/client/CordaRPCClientTest.kt index 13b23e32fe..0725bf8dfc 100644 --- a/client/src/integration-test/kotlin/net/corda/client/CordaRPCClientTest.kt +++ b/client/src/integration-test/kotlin/net/corda/client/CordaRPCClientTest.kt @@ -8,6 +8,7 @@ import net.corda.core.random63BitValue import net.corda.core.serialization.OpaqueBytes import net.corda.flows.CashCommand import net.corda.flows.CashFlow +import net.corda.node.driver.DriverBasedTest import net.corda.node.driver.NodeHandle import net.corda.node.driver.driver import net.corda.node.services.User @@ -24,32 +25,16 @@ import org.junit.Test import java.util.concurrent.CountDownLatch import kotlin.concurrent.thread -class CordaRPCClientTest { +class CordaRPCClientTest : DriverBasedTest() { private val rpcUser = User("user1", "test", permissions = setOf(startFlowPermission())) - private val stopDriver = CountDownLatch(1) - private var driverThread: Thread? = null private lateinit var client: CordaRPCClient private lateinit var driverInfo: NodeHandle - @Before - fun start() { - val driverStarted = CountDownLatch(1) - driverThread = thread { - driver(isDebug = true) { - driverInfo = startNode(rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow() - client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL()) - driverStarted.countDown() - stopDriver.await() - } - } - driverStarted.await() - } - - @After - fun stop() { - stopDriver.countDown() - driverThread?.join() + override fun setup() = driver(isDebug = true) { + driverInfo = startNode(rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow() + client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL()) + runTest() } @Test diff --git a/client/src/integration-test/kotlin/net/corda/client/NodeMonitorModelTest.kt b/client/src/integration-test/kotlin/net/corda/client/NodeMonitorModelTest.kt index f2b99f2266..25e7d89c15 100644 --- a/client/src/integration-test/kotlin/net/corda/client/NodeMonitorModelTest.kt +++ b/client/src/integration-test/kotlin/net/corda/client/NodeMonitorModelTest.kt @@ -19,7 +19,7 @@ import net.corda.core.serialization.OpaqueBytes import net.corda.core.transactions.SignedTransaction import net.corda.flows.CashCommand import net.corda.flows.CashFlow -import net.corda.node.driver.callSuspendResume +import net.corda.node.driver.DriverBasedTest import net.corda.node.driver.driver import net.corda.node.services.User import net.corda.node.services.config.configureTestSSL @@ -30,16 +30,13 @@ import net.corda.node.services.transactions.SimpleNotaryService import net.corda.testing.expect import net.corda.testing.expectEvents import net.corda.testing.sequence -import org.junit.After -import org.junit.Before import org.junit.Test import rx.Observable import rx.Observer -class NodeMonitorModelTest { +class NodeMonitorModelTest : DriverBasedTest() { lateinit var aliceNode: NodeInfo lateinit var notaryNode: NodeInfo - lateinit var stopDriver: () -> Unit lateinit var stateMachineTransactionMapping: Observable lateinit var stateMachineUpdates: Observable @@ -50,36 +47,26 @@ class NodeMonitorModelTest { lateinit var clientToService: Observer lateinit var newNode: (String) -> NodeInfo - @Before - fun start() { - stopDriver = callSuspendResume { suspend -> - driver { - val cashUser = User("user1", "test", permissions = setOf(startFlowPermission())) - val aliceNodeFuture = startNode("Alice", rpcUsers = listOf(cashUser)) - val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type))) + override fun setup() = driver { + val cashUser = User("user1", "test", permissions = setOf(startFlowPermission())) + val aliceNodeFuture = startNode("Alice", rpcUsers = listOf(cashUser)) + val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type))) - aliceNode = aliceNodeFuture.getOrThrow().nodeInfo - notaryNode = notaryNodeFuture.getOrThrow().nodeInfo - newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo } - val monitor = NodeMonitorModel() + aliceNode = aliceNodeFuture.getOrThrow().nodeInfo + notaryNode = notaryNodeFuture.getOrThrow().nodeInfo + newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo } + val monitor = NodeMonitorModel() - stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed() - stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed() - progressTracking = monitor.progressTracking.bufferUntilSubscribed() - transactions = monitor.transactions.bufferUntilSubscribed() - vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed() - networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() - clientToService = monitor.clientToService + stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed() + stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed() + progressTracking = monitor.progressTracking.bufferUntilSubscribed() + transactions = monitor.transactions.bufferUntilSubscribed() + vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed() + networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() + clientToService = monitor.clientToService - monitor.register(ArtemisMessagingComponent.toHostAndPort(aliceNode.address), configureTestSSL(), cashUser.username, cashUser.password) - suspend() - } - } - } - - @After - fun stop() { - stopDriver() + monitor.register(ArtemisMessagingComponent.toHostAndPort(aliceNode.address), configureTestSSL(), cashUser.username, cashUser.password) + runTest() } @Test diff --git a/docs/source/messaging.rst b/docs/source/messaging.rst index 6ef6205eb2..a2139589ac 100644 --- a/docs/source/messaging.rst +++ b/docs/source/messaging.rst @@ -61,8 +61,9 @@ for maintenance and other minor purposes. These are private queues the node may use to route messages to services. The queue name ends in the base 58 encoding of the service's owning identity key. There is at most one queue per service identity (but note that any one service may have several identities). The broker creates bridges to all nodes in the network advertising the service in - question. When a session is initiated with a service counterparty the handshake arrives on this queue, and once a - peer is picked the session continues on as normal. + question. When a session is initiated with a service counterparty the handshake is pushed onto this queue, and a + corresponding bridge is used to forward the message to an advertising peer's p2p queue. Once a peer is picked the + session continues on as normal. :``internal.networkmap``: This is another private queue just for the node which functions in a similar manner to the ``internal.peers.*`` queues diff --git a/node/src/integration-test/kotlin/net/corda/node/services/RaftValidatingNotaryServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/RaftValidatingNotaryServiceTests.kt index c2b3c8080f..f023887c11 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/RaftValidatingNotaryServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/RaftValidatingNotaryServiceTests.kt @@ -12,8 +12,8 @@ import net.corda.core.serialization.OpaqueBytes import net.corda.flows.CashCommand import net.corda.flows.CashFlow import net.corda.flows.CashFlowResult +import net.corda.node.driver.DriverBasedTest import net.corda.node.driver.NodeHandle -import net.corda.node.driver.callSuspendResume import net.corda.node.driver.driver import net.corda.node.services.config.configureTestSSL import net.corda.node.services.messaging.ArtemisMessagingComponent @@ -22,64 +22,51 @@ import net.corda.node.services.transactions.RaftValidatingNotaryService import net.corda.testing.expect import net.corda.testing.expectEvents import net.corda.testing.replicate -import org.junit.After -import org.junit.Before import org.junit.Test import rx.Observable import java.util.* import kotlin.test.assertEquals -class RaftValidatingNotaryServiceTests { - lateinit var stopDriver: () -> Unit +class RaftValidatingNotaryServiceTests : DriverBasedTest() { lateinit var alice: NodeInfo lateinit var notaries: List lateinit var aliceProxy: CordaRPCOps lateinit var raftNotaryIdentity: Party lateinit var notaryStateMachines: Observable> - @Before - fun start() { - stopDriver = callSuspendResume { suspend -> - driver { - // Start Alice and 3 raft notaries - val clusterSize = 3 - val testUser = User("test", "test", permissions = setOf(startFlowPermission())) - val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser)) - val notariesFuture = startNotaryCluster( - "Notary", - rpcUsers = listOf(testUser), - clusterSize = clusterSize, - type = RaftValidatingNotaryService.type - ) + override fun setup() = driver { + // Start Alice and 3 raft notaries + val clusterSize = 3 + val testUser = User("test", "test", permissions = setOf(startFlowPermission())) + val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser)) + val notariesFuture = startNotaryCluster( + "Notary", + rpcUsers = listOf(testUser), + clusterSize = clusterSize, + type = RaftValidatingNotaryService.type + ) - alice = aliceFuture.get().nodeInfo - val (notaryIdentity, notaryNodes) = notariesFuture.get() - raftNotaryIdentity = notaryIdentity - notaries = notaryNodes + alice = aliceFuture.get().nodeInfo + val (notaryIdentity, notaryNodes) = notariesFuture.get() + raftNotaryIdentity = notaryIdentity + notaries = notaryNodes - assertEquals(notaries.size, clusterSize) - assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size) + assertEquals(notaries.size, clusterSize) + assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size) - // Connect to Alice and the notaries - fun connectRpc(node: NodeInfo): CordaRPCOps { - val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL()) - client.start("test", "test") - return client.proxy() - } - aliceProxy = connectRpc(alice) - val notaryProxies = notaries.map { connectRpc(it.nodeInfo) } - notaryStateMachines = Observable.from(notaryProxies.map { proxy -> - proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) } - }).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed() - - suspend() - } + // Connect to Alice and the notaries + fun connectRpc(node: NodeInfo): CordaRPCOps { + val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL()) + client.start("test", "test") + return client.proxy() } - } + aliceProxy = connectRpc(alice) + val notaryProxies = notaries.map { connectRpc(it.nodeInfo) } + notaryStateMachines = Observable.from(notaryProxies.map { proxy -> + proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) } + }).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed() - @After - fun stop() { - stopDriver() + runTest() } @Test diff --git a/node/src/main/kotlin/net/corda/node/driver/Driver.kt b/node/src/main/kotlin/net/corda/node/driver/Driver.kt index 8bfafc5dbf..f33b087802 100644 --- a/node/src/main/kotlin/net/corda/node/driver/Driver.kt +++ b/node/src/main/kotlin/net/corda/node/driver/Driver.kt @@ -163,40 +163,6 @@ fun driver( dsl = dsl ) -/** - * Executes the passed in closure in a new thread, providing a function that suspends the closure, passing control back - * to the caller's context. The returned function may be used to then resume the closure. - * - * This can be used in conjunction with the driver to create @Before/@After blocks that start/shutdown the driver: - * - * val stopDriver = callSuspendResume { suspend -> - * driver(someOption = someValue) { - * .. initialise some test variables .. - * suspend() - * } - * } - * .. do tests .. - * stopDriver() - */ -fun callSuspendResume(closure: (suspend: () -> Unit) -> C): () -> C { - val suspendLatch = CountDownLatch(1) - val resumeLatch = CountDownLatch(1) - val returnFuture = CompletableFuture() - thread { - returnFuture.complete( - closure { - suspendLatch.countDown() - resumeLatch.await() - } - ) - } - suspendLatch.await() - return { - resumeLatch.countDown() - returnFuture.get() - } -} - /** * This is a helper method to allow extending of the DSL, along the lines of * interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface diff --git a/node/src/main/kotlin/net/corda/node/driver/DriverBasedTest.kt b/node/src/main/kotlin/net/corda/node/driver/DriverBasedTest.kt new file mode 100644 index 0000000000..63d8bc4d33 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/driver/DriverBasedTest.kt @@ -0,0 +1,39 @@ +package net.corda.node.driver + +import org.junit.After +import org.junit.Before +import java.util.concurrent.CountDownLatch +import kotlin.concurrent.thread + +abstract class DriverBasedTest { + private val stopDriver = CountDownLatch(1) + private var driverThread: Thread? = null + private lateinit var driverStarted: CountDownLatch + + protected sealed class RunTestToken { + internal object Token : RunTestToken() + } + + protected abstract fun setup(): RunTestToken + + protected fun DriverDSLExposedInterface.runTest(): RunTestToken { + driverStarted.countDown() + stopDriver.await() + return RunTestToken.Token + } + + @Before + fun start() { + driverStarted = CountDownLatch(1) + driverThread = thread { + setup() + } + driverStarted.await() + } + + @After + fun stop() { + stopDriver.countDown() + driverThread?.join() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index 96d2035e5c..55a8ecc6be 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -133,21 +133,21 @@ class ArtemisMessagingServer(override val config: NodeConfiguration, } val addressesToCreateBridgesTo = HashSet() - val addressesToRemoveBridgesTo = HashSet() + val addressesToRemoveBridgesFrom = HashSet() when (change) { is MapChange.Modified -> { addAddresses(change.node, addressesToCreateBridgesTo) - addAddresses(change.previousNode, addressesToRemoveBridgesTo) + addAddresses(change.previousNode, addressesToRemoveBridgesFrom) } is MapChange.Removed -> { - addAddresses(change.node, addressesToRemoveBridgesTo) + addAddresses(change.node, addressesToRemoveBridgesFrom) } is MapChange.Added -> { addAddresses(change.node, addressesToCreateBridgesTo) } } - (addressesToRemoveBridgesTo - addressesToCreateBridgesTo).forEach { + (addressesToRemoveBridgesFrom - addressesToCreateBridgesTo).forEach { maybeDestroyBridge(bridgeNameForAddress(it)) } addressesToCreateBridgesTo.forEach { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 4ec2710fd6..e2487cfcc1 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -169,14 +169,14 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable private inline fun receiveInternal(session: FlowSession): M { - return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).second + return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message } private inline fun sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M { - return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).second + return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message } - private inline fun sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): Pair { + private inline fun sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage { return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)) } @@ -215,8 +215,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): Pair { - fun getReceivedMessage(): Pair? = receiveRequest.session.receivedMessages.poll() + private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): ReceivedSessionMessage { + fun getReceivedMessage(): ReceivedSessionMessage? = receiveRequest.session.receivedMessages.poll() val polledMessage = getReceivedMessage() val receivedMessage = if (polledMessage != null) { @@ -232,11 +232,11 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, ?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest") } - if (receivedMessage.second is SessionEnd) { + if (receivedMessage.message is SessionEnd) { openSessions.values.remove(receiveRequest.session) throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest") - } else if (receiveRequest.receiveType.isInstance(receivedMessage.second)) { - return Pair(receivedMessage.first, receiveRequest.receiveType.cast(receivedMessage.second)) + } else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) { + return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message)) } else { throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest") } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 5ec0e726a4..bb15c0d548 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -244,7 +244,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, if (message is SessionEnd) { openSessions.remove(message.recipientSessionId) } - session.receivedMessages += Pair(otherParty, message) + session.receivedMessages += ReceivedSessionMessage(otherParty, message) if (session.waitingForResponse) { // We only want to resume once, so immediately reset the flag. session.waitingForResponse = false @@ -277,7 +277,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val psm = createFiber(flow) val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId)) if (sessionInit.firstPayload != null) { - session.receivedMessages += Pair(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload)) + session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload)) } openSessions[session.ourSessionId] = session psm.openSessions[Pair(flow, otherParty)] = session @@ -453,6 +453,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, serviceHub.networkService.send(sessionTopic, message, address) } + data class ReceivedSessionMessage(val sendingParty: Party, val message: M) interface SessionMessage @@ -509,7 +510,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, var state: FlowSessionState, @Volatile var waitingForResponse: Boolean = false ) { - val receivedMessages = ConcurrentLinkedQueue>() + val receivedMessages = ConcurrentLinkedQueue>() val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*> }