Add ReceivedSessionMessage, DriverBasedTest re #57

This commit is contained in:
exfalso
2016-12-15 11:35:24 +00:00
parent 8ea4c258f1
commit 53bbb57345
9 changed files with 113 additions and 147 deletions

View File

@ -8,6 +8,7 @@ import net.corda.core.random63BitValue
import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.OpaqueBytes
import net.corda.flows.CashCommand import net.corda.flows.CashCommand
import net.corda.flows.CashFlow import net.corda.flows.CashFlow
import net.corda.node.driver.DriverBasedTest
import net.corda.node.driver.NodeHandle import net.corda.node.driver.NodeHandle
import net.corda.node.driver.driver import net.corda.node.driver.driver
import net.corda.node.services.User import net.corda.node.services.User
@ -24,32 +25,16 @@ import org.junit.Test
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread import kotlin.concurrent.thread
class CordaRPCClientTest { class CordaRPCClientTest : DriverBasedTest() {
private val rpcUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>())) private val rpcUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
private val stopDriver = CountDownLatch(1)
private var driverThread: Thread? = null
private lateinit var client: CordaRPCClient private lateinit var client: CordaRPCClient
private lateinit var driverInfo: NodeHandle private lateinit var driverInfo: NodeHandle
@Before override fun setup() = driver(isDebug = true) {
fun start() { driverInfo = startNode(rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow()
val driverStarted = CountDownLatch(1) client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL())
driverThread = thread { runTest()
driver(isDebug = true) {
driverInfo = startNode(rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow()
client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL())
driverStarted.countDown()
stopDriver.await()
}
}
driverStarted.await()
}
@After
fun stop() {
stopDriver.countDown()
driverThread?.join()
} }
@Test @Test

View File

@ -19,7 +19,7 @@ import net.corda.core.serialization.OpaqueBytes
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.flows.CashCommand import net.corda.flows.CashCommand
import net.corda.flows.CashFlow import net.corda.flows.CashFlow
import net.corda.node.driver.callSuspendResume import net.corda.node.driver.DriverBasedTest
import net.corda.node.driver.driver import net.corda.node.driver.driver
import net.corda.node.services.User import net.corda.node.services.User
import net.corda.node.services.config.configureTestSSL import net.corda.node.services.config.configureTestSSL
@ -30,16 +30,13 @@ import net.corda.node.services.transactions.SimpleNotaryService
import net.corda.testing.expect import net.corda.testing.expect
import net.corda.testing.expectEvents import net.corda.testing.expectEvents
import net.corda.testing.sequence import net.corda.testing.sequence
import org.junit.After
import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
import rx.Observer import rx.Observer
class NodeMonitorModelTest { class NodeMonitorModelTest : DriverBasedTest() {
lateinit var aliceNode: NodeInfo lateinit var aliceNode: NodeInfo
lateinit var notaryNode: NodeInfo lateinit var notaryNode: NodeInfo
lateinit var stopDriver: () -> Unit
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping> lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
lateinit var stateMachineUpdates: Observable<StateMachineUpdate> lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
@ -50,36 +47,26 @@ class NodeMonitorModelTest {
lateinit var clientToService: Observer<CashCommand> lateinit var clientToService: Observer<CashCommand>
lateinit var newNode: (String) -> NodeInfo lateinit var newNode: (String) -> NodeInfo
@Before override fun setup() = driver {
fun start() { val cashUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
stopDriver = callSuspendResume { suspend -> val aliceNodeFuture = startNode("Alice", rpcUsers = listOf(cashUser))
driver { val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
val cashUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
val aliceNodeFuture = startNode("Alice", rpcUsers = listOf(cashUser))
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
aliceNode = aliceNodeFuture.getOrThrow().nodeInfo aliceNode = aliceNodeFuture.getOrThrow().nodeInfo
notaryNode = notaryNodeFuture.getOrThrow().nodeInfo notaryNode = notaryNodeFuture.getOrThrow().nodeInfo
newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo } newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo }
val monitor = NodeMonitorModel() val monitor = NodeMonitorModel()
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed() stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed() stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed()
progressTracking = monitor.progressTracking.bufferUntilSubscribed() progressTracking = monitor.progressTracking.bufferUntilSubscribed()
transactions = monitor.transactions.bufferUntilSubscribed() transactions = monitor.transactions.bufferUntilSubscribed()
vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed() vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed()
networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() networkMapUpdates = monitor.networkMap.bufferUntilSubscribed()
clientToService = monitor.clientToService clientToService = monitor.clientToService
monitor.register(ArtemisMessagingComponent.toHostAndPort(aliceNode.address), configureTestSSL(), cashUser.username, cashUser.password) monitor.register(ArtemisMessagingComponent.toHostAndPort(aliceNode.address), configureTestSSL(), cashUser.username, cashUser.password)
suspend() runTest()
}
}
}
@After
fun stop() {
stopDriver()
} }
@Test @Test

View File

@ -61,8 +61,9 @@ for maintenance and other minor purposes.
These are private queues the node may use to route messages to services. The queue name ends in the base 58 encoding These are private queues the node may use to route messages to services. The queue name ends in the base 58 encoding
of the service's owning identity key. There is at most one queue per service identity (but note that any one service of the service's owning identity key. There is at most one queue per service identity (but note that any one service
may have several identities). The broker creates bridges to all nodes in the network advertising the service in may have several identities). The broker creates bridges to all nodes in the network advertising the service in
question. When a session is initiated with a service counterparty the handshake arrives on this queue, and once a question. When a session is initiated with a service counterparty the handshake is pushed onto this queue, and a
peer is picked the session continues on as normal. corresponding bridge is used to forward the message to an advertising peer's p2p queue. Once a peer is picked the
session continues on as normal.
:``internal.networkmap``: :``internal.networkmap``:
This is another private queue just for the node which functions in a similar manner to the ``internal.peers.*`` queues This is another private queue just for the node which functions in a similar manner to the ``internal.peers.*`` queues

View File

@ -12,8 +12,8 @@ import net.corda.core.serialization.OpaqueBytes
import net.corda.flows.CashCommand import net.corda.flows.CashCommand
import net.corda.flows.CashFlow import net.corda.flows.CashFlow
import net.corda.flows.CashFlowResult import net.corda.flows.CashFlowResult
import net.corda.node.driver.DriverBasedTest
import net.corda.node.driver.NodeHandle import net.corda.node.driver.NodeHandle
import net.corda.node.driver.callSuspendResume
import net.corda.node.driver.driver import net.corda.node.driver.driver
import net.corda.node.services.config.configureTestSSL import net.corda.node.services.config.configureTestSSL
import net.corda.node.services.messaging.ArtemisMessagingComponent import net.corda.node.services.messaging.ArtemisMessagingComponent
@ -22,64 +22,51 @@ import net.corda.node.services.transactions.RaftValidatingNotaryService
import net.corda.testing.expect import net.corda.testing.expect
import net.corda.testing.expectEvents import net.corda.testing.expectEvents
import net.corda.testing.replicate import net.corda.testing.replicate
import org.junit.After
import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
import java.util.* import java.util.*
import kotlin.test.assertEquals import kotlin.test.assertEquals
class RaftValidatingNotaryServiceTests { class RaftValidatingNotaryServiceTests : DriverBasedTest() {
lateinit var stopDriver: () -> Unit
lateinit var alice: NodeInfo lateinit var alice: NodeInfo
lateinit var notaries: List<NodeHandle> lateinit var notaries: List<NodeHandle>
lateinit var aliceProxy: CordaRPCOps lateinit var aliceProxy: CordaRPCOps
lateinit var raftNotaryIdentity: Party lateinit var raftNotaryIdentity: Party
lateinit var notaryStateMachines: Observable<Pair<NodeInfo, StateMachineUpdate>> lateinit var notaryStateMachines: Observable<Pair<NodeInfo, StateMachineUpdate>>
@Before override fun setup() = driver {
fun start() { // Start Alice and 3 raft notaries
stopDriver = callSuspendResume { suspend -> val clusterSize = 3
driver { val testUser = User("test", "test", permissions = setOf(startFlowPermission<CashFlow>()))
// Start Alice and 3 raft notaries val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser))
val clusterSize = 3 val notariesFuture = startNotaryCluster(
val testUser = User("test", "test", permissions = setOf(startFlowPermission<CashFlow>())) "Notary",
val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser)) rpcUsers = listOf(testUser),
val notariesFuture = startNotaryCluster( clusterSize = clusterSize,
"Notary", type = RaftValidatingNotaryService.type
rpcUsers = listOf(testUser), )
clusterSize = clusterSize,
type = RaftValidatingNotaryService.type
)
alice = aliceFuture.get().nodeInfo alice = aliceFuture.get().nodeInfo
val (notaryIdentity, notaryNodes) = notariesFuture.get() val (notaryIdentity, notaryNodes) = notariesFuture.get()
raftNotaryIdentity = notaryIdentity raftNotaryIdentity = notaryIdentity
notaries = notaryNodes notaries = notaryNodes
assertEquals(notaries.size, clusterSize) assertEquals(notaries.size, clusterSize)
assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size) assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size)
// Connect to Alice and the notaries // Connect to Alice and the notaries
fun connectRpc(node: NodeInfo): CordaRPCOps { fun connectRpc(node: NodeInfo): CordaRPCOps {
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL()) val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL())
client.start("test", "test") client.start("test", "test")
return client.proxy() return client.proxy()
}
aliceProxy = connectRpc(alice)
val notaryProxies = notaries.map { connectRpc(it.nodeInfo) }
notaryStateMachines = Observable.from(notaryProxies.map { proxy ->
proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) }
}).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed()
suspend()
}
} }
} aliceProxy = connectRpc(alice)
val notaryProxies = notaries.map { connectRpc(it.nodeInfo) }
notaryStateMachines = Observable.from(notaryProxies.map { proxy ->
proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) }
}).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed()
@After runTest()
fun stop() {
stopDriver()
} }
@Test @Test

View File

@ -163,40 +163,6 @@ fun <A> driver(
dsl = dsl dsl = dsl
) )
/**
* Executes the passed in closure in a new thread, providing a function that suspends the closure, passing control back
* to the caller's context. The returned function may be used to then resume the closure.
*
* This can be used in conjunction with the driver to create @Before/@After blocks that start/shutdown the driver:
*
* val stopDriver = callSuspendResume { suspend ->
* driver(someOption = someValue) {
* .. initialise some test variables ..
* suspend()
* }
* }
* .. do tests ..
* stopDriver()
*/
fun <C> callSuspendResume(closure: (suspend: () -> Unit) -> C): () -> C {
val suspendLatch = CountDownLatch(1)
val resumeLatch = CountDownLatch(1)
val returnFuture = CompletableFuture<C>()
thread {
returnFuture.complete(
closure {
suspendLatch.countDown()
resumeLatch.await()
}
)
}
suspendLatch.await()
return {
resumeLatch.countDown()
returnFuture.get()
}
}
/** /**
* This is a helper method to allow extending of the DSL, along the lines of * This is a helper method to allow extending of the DSL, along the lines of
* interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface * interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface

View File

@ -0,0 +1,39 @@
package net.corda.node.driver
import org.junit.After
import org.junit.Before
import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread
abstract class DriverBasedTest {
private val stopDriver = CountDownLatch(1)
private var driverThread: Thread? = null
private lateinit var driverStarted: CountDownLatch
protected sealed class RunTestToken {
internal object Token : RunTestToken()
}
protected abstract fun setup(): RunTestToken
protected fun DriverDSLExposedInterface.runTest(): RunTestToken {
driverStarted.countDown()
stopDriver.await()
return RunTestToken.Token
}
@Before
fun start() {
driverStarted = CountDownLatch(1)
driverThread = thread {
setup()
}
driverStarted.await()
}
@After
fun stop() {
stopDriver.countDown()
driverThread?.join()
}
}

View File

@ -133,21 +133,21 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
} }
val addressesToCreateBridgesTo = HashSet<ArtemisPeerAddress>() val addressesToCreateBridgesTo = HashSet<ArtemisPeerAddress>()
val addressesToRemoveBridgesTo = HashSet<ArtemisPeerAddress>() val addressesToRemoveBridgesFrom = HashSet<ArtemisPeerAddress>()
when (change) { when (change) {
is MapChange.Modified -> { is MapChange.Modified -> {
addAddresses(change.node, addressesToCreateBridgesTo) addAddresses(change.node, addressesToCreateBridgesTo)
addAddresses(change.previousNode, addressesToRemoveBridgesTo) addAddresses(change.previousNode, addressesToRemoveBridgesFrom)
} }
is MapChange.Removed -> { is MapChange.Removed -> {
addAddresses(change.node, addressesToRemoveBridgesTo) addAddresses(change.node, addressesToRemoveBridgesFrom)
} }
is MapChange.Added -> { is MapChange.Added -> {
addAddresses(change.node, addressesToCreateBridgesTo) addAddresses(change.node, addressesToCreateBridgesTo)
} }
} }
(addressesToRemoveBridgesTo - addressesToCreateBridgesTo).forEach { (addressesToRemoveBridgesFrom - addressesToCreateBridgesTo).forEach {
maybeDestroyBridge(bridgeNameForAddress(it)) maybeDestroyBridge(bridgeNameForAddress(it))
} }
addressesToCreateBridgesTo.forEach { addressesToCreateBridgesTo.forEach {

View File

@ -169,14 +169,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M { private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).second return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message
} }
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M { private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).second return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message
} }
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): Pair<Party, M> { private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage<M> {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)) return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
} }
@ -215,8 +215,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): Pair<Party, M> { private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
fun getReceivedMessage(): Pair<Party, ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll() fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
val polledMessage = getReceivedMessage() val polledMessage = getReceivedMessage()
val receivedMessage = if (polledMessage != null) { val receivedMessage = if (polledMessage != null) {
@ -232,11 +232,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest") ?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
} }
if (receivedMessage.second is SessionEnd) { if (receivedMessage.message is SessionEnd) {
openSessions.values.remove(receiveRequest.session) openSessions.values.remove(receiveRequest.session)
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest") throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
} else if (receiveRequest.receiveType.isInstance(receivedMessage.second)) { } else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
return Pair(receivedMessage.first, receiveRequest.receiveType.cast(receivedMessage.second)) return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message))
} else { } else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest") throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
} }

View File

@ -244,7 +244,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (message is SessionEnd) { if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId) openSessions.remove(message.recipientSessionId)
} }
session.receivedMessages += Pair(otherParty, message) session.receivedMessages += ReceivedSessionMessage(otherParty, message)
if (session.waitingForResponse) { if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag. // We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false session.waitingForResponse = false
@ -277,7 +277,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val psm = createFiber(flow) val psm = createFiber(flow)
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId)) val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
if (sessionInit.firstPayload != null) { if (sessionInit.firstPayload != null) {
session.receivedMessages += Pair(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload)) session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
} }
openSessions[session.ourSessionId] = session openSessions[session.ourSessionId] = session
psm.openSessions[Pair(flow, otherParty)] = session psm.openSessions[Pair(flow, otherParty)] = session
@ -453,6 +453,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
serviceHub.networkService.send(sessionTopic, message, address) serviceHub.networkService.send(sessionTopic, message, address)
} }
data class ReceivedSessionMessage<out M : SessionMessage>(val sendingParty: Party, val message: M)
interface SessionMessage interface SessionMessage
@ -509,7 +510,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
var state: FlowSessionState, var state: FlowSessionState,
@Volatile var waitingForResponse: Boolean = false @Volatile var waitingForResponse: Boolean = false
) { ) {
val receivedMessages = ConcurrentLinkedQueue<Pair<Party, ExistingSessionMessage>>() val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*> val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
} }