Add ReceivedSessionMessage, DriverBasedTest re #57

This commit is contained in:
exfalso
2016-12-15 11:35:24 +00:00
parent 8ea4c258f1
commit 53bbb57345
9 changed files with 113 additions and 147 deletions

View File

@ -12,8 +12,8 @@ import net.corda.core.serialization.OpaqueBytes
import net.corda.flows.CashCommand
import net.corda.flows.CashFlow
import net.corda.flows.CashFlowResult
import net.corda.node.driver.DriverBasedTest
import net.corda.node.driver.NodeHandle
import net.corda.node.driver.callSuspendResume
import net.corda.node.driver.driver
import net.corda.node.services.config.configureTestSSL
import net.corda.node.services.messaging.ArtemisMessagingComponent
@ -22,64 +22,51 @@ import net.corda.node.services.transactions.RaftValidatingNotaryService
import net.corda.testing.expect
import net.corda.testing.expectEvents
import net.corda.testing.replicate
import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.Observable
import java.util.*
import kotlin.test.assertEquals
class RaftValidatingNotaryServiceTests {
lateinit var stopDriver: () -> Unit
class RaftValidatingNotaryServiceTests : DriverBasedTest() {
lateinit var alice: NodeInfo
lateinit var notaries: List<NodeHandle>
lateinit var aliceProxy: CordaRPCOps
lateinit var raftNotaryIdentity: Party
lateinit var notaryStateMachines: Observable<Pair<NodeInfo, StateMachineUpdate>>
@Before
fun start() {
stopDriver = callSuspendResume { suspend ->
driver {
// Start Alice and 3 raft notaries
val clusterSize = 3
val testUser = User("test", "test", permissions = setOf(startFlowPermission<CashFlow>()))
val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser))
val notariesFuture = startNotaryCluster(
"Notary",
rpcUsers = listOf(testUser),
clusterSize = clusterSize,
type = RaftValidatingNotaryService.type
)
override fun setup() = driver {
// Start Alice and 3 raft notaries
val clusterSize = 3
val testUser = User("test", "test", permissions = setOf(startFlowPermission<CashFlow>()))
val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser))
val notariesFuture = startNotaryCluster(
"Notary",
rpcUsers = listOf(testUser),
clusterSize = clusterSize,
type = RaftValidatingNotaryService.type
)
alice = aliceFuture.get().nodeInfo
val (notaryIdentity, notaryNodes) = notariesFuture.get()
raftNotaryIdentity = notaryIdentity
notaries = notaryNodes
alice = aliceFuture.get().nodeInfo
val (notaryIdentity, notaryNodes) = notariesFuture.get()
raftNotaryIdentity = notaryIdentity
notaries = notaryNodes
assertEquals(notaries.size, clusterSize)
assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size)
assertEquals(notaries.size, clusterSize)
assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size)
// Connect to Alice and the notaries
fun connectRpc(node: NodeInfo): CordaRPCOps {
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL())
client.start("test", "test")
return client.proxy()
}
aliceProxy = connectRpc(alice)
val notaryProxies = notaries.map { connectRpc(it.nodeInfo) }
notaryStateMachines = Observable.from(notaryProxies.map { proxy ->
proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) }
}).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed()
suspend()
}
// Connect to Alice and the notaries
fun connectRpc(node: NodeInfo): CordaRPCOps {
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL())
client.start("test", "test")
return client.proxy()
}
}
aliceProxy = connectRpc(alice)
val notaryProxies = notaries.map { connectRpc(it.nodeInfo) }
notaryStateMachines = Observable.from(notaryProxies.map { proxy ->
proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) }
}).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed()
@After
fun stop() {
stopDriver()
runTest()
}
@Test

View File

@ -163,40 +163,6 @@ fun <A> driver(
dsl = dsl
)
/**
* Executes the passed in closure in a new thread, providing a function that suspends the closure, passing control back
* to the caller's context. The returned function may be used to then resume the closure.
*
* This can be used in conjunction with the driver to create @Before/@After blocks that start/shutdown the driver:
*
* val stopDriver = callSuspendResume { suspend ->
* driver(someOption = someValue) {
* .. initialise some test variables ..
* suspend()
* }
* }
* .. do tests ..
* stopDriver()
*/
fun <C> callSuspendResume(closure: (suspend: () -> Unit) -> C): () -> C {
val suspendLatch = CountDownLatch(1)
val resumeLatch = CountDownLatch(1)
val returnFuture = CompletableFuture<C>()
thread {
returnFuture.complete(
closure {
suspendLatch.countDown()
resumeLatch.await()
}
)
}
suspendLatch.await()
return {
resumeLatch.countDown()
returnFuture.get()
}
}
/**
* This is a helper method to allow extending of the DSL, along the lines of
* interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface

View File

@ -0,0 +1,39 @@
package net.corda.node.driver
import org.junit.After
import org.junit.Before
import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread
abstract class DriverBasedTest {
private val stopDriver = CountDownLatch(1)
private var driverThread: Thread? = null
private lateinit var driverStarted: CountDownLatch
protected sealed class RunTestToken {
internal object Token : RunTestToken()
}
protected abstract fun setup(): RunTestToken
protected fun DriverDSLExposedInterface.runTest(): RunTestToken {
driverStarted.countDown()
stopDriver.await()
return RunTestToken.Token
}
@Before
fun start() {
driverStarted = CountDownLatch(1)
driverThread = thread {
setup()
}
driverStarted.await()
}
@After
fun stop() {
stopDriver.countDown()
driverThread?.join()
}
}

View File

@ -133,21 +133,21 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
}
val addressesToCreateBridgesTo = HashSet<ArtemisPeerAddress>()
val addressesToRemoveBridgesTo = HashSet<ArtemisPeerAddress>()
val addressesToRemoveBridgesFrom = HashSet<ArtemisPeerAddress>()
when (change) {
is MapChange.Modified -> {
addAddresses(change.node, addressesToCreateBridgesTo)
addAddresses(change.previousNode, addressesToRemoveBridgesTo)
addAddresses(change.previousNode, addressesToRemoveBridgesFrom)
}
is MapChange.Removed -> {
addAddresses(change.node, addressesToRemoveBridgesTo)
addAddresses(change.node, addressesToRemoveBridgesFrom)
}
is MapChange.Added -> {
addAddresses(change.node, addressesToCreateBridgesTo)
}
}
(addressesToRemoveBridgesTo - addressesToCreateBridgesTo).forEach {
(addressesToRemoveBridgesFrom - addressesToCreateBridgesTo).forEach {
maybeDestroyBridge(bridgeNameForAddress(it))
}
addressesToCreateBridgesTo.forEach {

View File

@ -169,14 +169,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).second
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message
}
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).second
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message
}
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): Pair<Party, M> {
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage<M> {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
}
@ -215,8 +215,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): Pair<Party, M> {
fun getReceivedMessage(): Pair<Party, ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
val polledMessage = getReceivedMessage()
val receivedMessage = if (polledMessage != null) {
@ -232,11 +232,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
}
if (receivedMessage.second is SessionEnd) {
if (receivedMessage.message is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
} else if (receiveRequest.receiveType.isInstance(receivedMessage.second)) {
return Pair(receivedMessage.first, receiveRequest.receiveType.cast(receivedMessage.second))
} else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message))
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
}

View File

@ -244,7 +244,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += Pair(otherParty, message)
session.receivedMessages += ReceivedSessionMessage(otherParty, message)
if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false
@ -277,7 +277,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val psm = createFiber(flow)
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
if (sessionInit.firstPayload != null) {
session.receivedMessages += Pair(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
}
openSessions[session.ourSessionId] = session
psm.openSessions[Pair(flow, otherParty)] = session
@ -453,6 +453,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
serviceHub.networkService.send(sessionTopic, message, address)
}
data class ReceivedSessionMessage<out M : SessionMessage>(val sendingParty: Party, val message: M)
interface SessionMessage
@ -509,7 +510,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
var state: FlowSessionState,
@Volatile var waitingForResponse: Boolean = false
) {
val receivedMessages = ConcurrentLinkedQueue<Pair<Party, ExistingSessionMessage>>()
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
}