mirror of
https://github.com/corda/corda.git
synced 2025-06-18 15:18:16 +00:00
Add ReceivedSessionMessage, DriverBasedTest re #57
This commit is contained in:
@ -8,6 +8,7 @@ import net.corda.core.random63BitValue
|
|||||||
import net.corda.core.serialization.OpaqueBytes
|
import net.corda.core.serialization.OpaqueBytes
|
||||||
import net.corda.flows.CashCommand
|
import net.corda.flows.CashCommand
|
||||||
import net.corda.flows.CashFlow
|
import net.corda.flows.CashFlow
|
||||||
|
import net.corda.node.driver.DriverBasedTest
|
||||||
import net.corda.node.driver.NodeHandle
|
import net.corda.node.driver.NodeHandle
|
||||||
import net.corda.node.driver.driver
|
import net.corda.node.driver.driver
|
||||||
import net.corda.node.services.User
|
import net.corda.node.services.User
|
||||||
@ -24,32 +25,16 @@ import org.junit.Test
|
|||||||
import java.util.concurrent.CountDownLatch
|
import java.util.concurrent.CountDownLatch
|
||||||
import kotlin.concurrent.thread
|
import kotlin.concurrent.thread
|
||||||
|
|
||||||
class CordaRPCClientTest {
|
class CordaRPCClientTest : DriverBasedTest() {
|
||||||
|
|
||||||
private val rpcUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
|
private val rpcUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
|
||||||
private val stopDriver = CountDownLatch(1)
|
|
||||||
private var driverThread: Thread? = null
|
|
||||||
private lateinit var client: CordaRPCClient
|
private lateinit var client: CordaRPCClient
|
||||||
private lateinit var driverInfo: NodeHandle
|
private lateinit var driverInfo: NodeHandle
|
||||||
|
|
||||||
@Before
|
override fun setup() = driver(isDebug = true) {
|
||||||
fun start() {
|
driverInfo = startNode(rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow()
|
||||||
val driverStarted = CountDownLatch(1)
|
client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL())
|
||||||
driverThread = thread {
|
runTest()
|
||||||
driver(isDebug = true) {
|
|
||||||
driverInfo = startNode(rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow()
|
|
||||||
client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL())
|
|
||||||
driverStarted.countDown()
|
|
||||||
stopDriver.await()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
driverStarted.await()
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
|
||||||
fun stop() {
|
|
||||||
stopDriver.countDown()
|
|
||||||
driverThread?.join()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -19,7 +19,7 @@ import net.corda.core.serialization.OpaqueBytes
|
|||||||
import net.corda.core.transactions.SignedTransaction
|
import net.corda.core.transactions.SignedTransaction
|
||||||
import net.corda.flows.CashCommand
|
import net.corda.flows.CashCommand
|
||||||
import net.corda.flows.CashFlow
|
import net.corda.flows.CashFlow
|
||||||
import net.corda.node.driver.callSuspendResume
|
import net.corda.node.driver.DriverBasedTest
|
||||||
import net.corda.node.driver.driver
|
import net.corda.node.driver.driver
|
||||||
import net.corda.node.services.User
|
import net.corda.node.services.User
|
||||||
import net.corda.node.services.config.configureTestSSL
|
import net.corda.node.services.config.configureTestSSL
|
||||||
@ -30,16 +30,13 @@ import net.corda.node.services.transactions.SimpleNotaryService
|
|||||||
import net.corda.testing.expect
|
import net.corda.testing.expect
|
||||||
import net.corda.testing.expectEvents
|
import net.corda.testing.expectEvents
|
||||||
import net.corda.testing.sequence
|
import net.corda.testing.sequence
|
||||||
import org.junit.After
|
|
||||||
import org.junit.Before
|
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import rx.Observable
|
import rx.Observable
|
||||||
import rx.Observer
|
import rx.Observer
|
||||||
|
|
||||||
class NodeMonitorModelTest {
|
class NodeMonitorModelTest : DriverBasedTest() {
|
||||||
lateinit var aliceNode: NodeInfo
|
lateinit var aliceNode: NodeInfo
|
||||||
lateinit var notaryNode: NodeInfo
|
lateinit var notaryNode: NodeInfo
|
||||||
lateinit var stopDriver: () -> Unit
|
|
||||||
|
|
||||||
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
|
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
|
||||||
lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
|
lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
|
||||||
@ -50,36 +47,26 @@ class NodeMonitorModelTest {
|
|||||||
lateinit var clientToService: Observer<CashCommand>
|
lateinit var clientToService: Observer<CashCommand>
|
||||||
lateinit var newNode: (String) -> NodeInfo
|
lateinit var newNode: (String) -> NodeInfo
|
||||||
|
|
||||||
@Before
|
override fun setup() = driver {
|
||||||
fun start() {
|
val cashUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
|
||||||
stopDriver = callSuspendResume { suspend ->
|
val aliceNodeFuture = startNode("Alice", rpcUsers = listOf(cashUser))
|
||||||
driver {
|
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
|
||||||
val cashUser = User("user1", "test", permissions = setOf(startFlowPermission<CashFlow>()))
|
|
||||||
val aliceNodeFuture = startNode("Alice", rpcUsers = listOf(cashUser))
|
|
||||||
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
|
|
||||||
|
|
||||||
aliceNode = aliceNodeFuture.getOrThrow().nodeInfo
|
aliceNode = aliceNodeFuture.getOrThrow().nodeInfo
|
||||||
notaryNode = notaryNodeFuture.getOrThrow().nodeInfo
|
notaryNode = notaryNodeFuture.getOrThrow().nodeInfo
|
||||||
newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo }
|
newNode = { nodeName -> startNode(nodeName).getOrThrow().nodeInfo }
|
||||||
val monitor = NodeMonitorModel()
|
val monitor = NodeMonitorModel()
|
||||||
|
|
||||||
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
|
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
|
||||||
stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed()
|
stateMachineUpdates = monitor.stateMachineUpdates.bufferUntilSubscribed()
|
||||||
progressTracking = monitor.progressTracking.bufferUntilSubscribed()
|
progressTracking = monitor.progressTracking.bufferUntilSubscribed()
|
||||||
transactions = monitor.transactions.bufferUntilSubscribed()
|
transactions = monitor.transactions.bufferUntilSubscribed()
|
||||||
vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed()
|
vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed()
|
||||||
networkMapUpdates = monitor.networkMap.bufferUntilSubscribed()
|
networkMapUpdates = monitor.networkMap.bufferUntilSubscribed()
|
||||||
clientToService = monitor.clientToService
|
clientToService = monitor.clientToService
|
||||||
|
|
||||||
monitor.register(ArtemisMessagingComponent.toHostAndPort(aliceNode.address), configureTestSSL(), cashUser.username, cashUser.password)
|
monitor.register(ArtemisMessagingComponent.toHostAndPort(aliceNode.address), configureTestSSL(), cashUser.username, cashUser.password)
|
||||||
suspend()
|
runTest()
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@After
|
|
||||||
fun stop() {
|
|
||||||
stopDriver()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -61,8 +61,9 @@ for maintenance and other minor purposes.
|
|||||||
These are private queues the node may use to route messages to services. The queue name ends in the base 58 encoding
|
These are private queues the node may use to route messages to services. The queue name ends in the base 58 encoding
|
||||||
of the service's owning identity key. There is at most one queue per service identity (but note that any one service
|
of the service's owning identity key. There is at most one queue per service identity (but note that any one service
|
||||||
may have several identities). The broker creates bridges to all nodes in the network advertising the service in
|
may have several identities). The broker creates bridges to all nodes in the network advertising the service in
|
||||||
question. When a session is initiated with a service counterparty the handshake arrives on this queue, and once a
|
question. When a session is initiated with a service counterparty the handshake is pushed onto this queue, and a
|
||||||
peer is picked the session continues on as normal.
|
corresponding bridge is used to forward the message to an advertising peer's p2p queue. Once a peer is picked the
|
||||||
|
session continues on as normal.
|
||||||
|
|
||||||
:``internal.networkmap``:
|
:``internal.networkmap``:
|
||||||
This is another private queue just for the node which functions in a similar manner to the ``internal.peers.*`` queues
|
This is another private queue just for the node which functions in a similar manner to the ``internal.peers.*`` queues
|
||||||
|
@ -12,8 +12,8 @@ import net.corda.core.serialization.OpaqueBytes
|
|||||||
import net.corda.flows.CashCommand
|
import net.corda.flows.CashCommand
|
||||||
import net.corda.flows.CashFlow
|
import net.corda.flows.CashFlow
|
||||||
import net.corda.flows.CashFlowResult
|
import net.corda.flows.CashFlowResult
|
||||||
|
import net.corda.node.driver.DriverBasedTest
|
||||||
import net.corda.node.driver.NodeHandle
|
import net.corda.node.driver.NodeHandle
|
||||||
import net.corda.node.driver.callSuspendResume
|
|
||||||
import net.corda.node.driver.driver
|
import net.corda.node.driver.driver
|
||||||
import net.corda.node.services.config.configureTestSSL
|
import net.corda.node.services.config.configureTestSSL
|
||||||
import net.corda.node.services.messaging.ArtemisMessagingComponent
|
import net.corda.node.services.messaging.ArtemisMessagingComponent
|
||||||
@ -22,64 +22,51 @@ import net.corda.node.services.transactions.RaftValidatingNotaryService
|
|||||||
import net.corda.testing.expect
|
import net.corda.testing.expect
|
||||||
import net.corda.testing.expectEvents
|
import net.corda.testing.expectEvents
|
||||||
import net.corda.testing.replicate
|
import net.corda.testing.replicate
|
||||||
import org.junit.After
|
|
||||||
import org.junit.Before
|
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import rx.Observable
|
import rx.Observable
|
||||||
import java.util.*
|
import java.util.*
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
class RaftValidatingNotaryServiceTests {
|
class RaftValidatingNotaryServiceTests : DriverBasedTest() {
|
||||||
lateinit var stopDriver: () -> Unit
|
|
||||||
lateinit var alice: NodeInfo
|
lateinit var alice: NodeInfo
|
||||||
lateinit var notaries: List<NodeHandle>
|
lateinit var notaries: List<NodeHandle>
|
||||||
lateinit var aliceProxy: CordaRPCOps
|
lateinit var aliceProxy: CordaRPCOps
|
||||||
lateinit var raftNotaryIdentity: Party
|
lateinit var raftNotaryIdentity: Party
|
||||||
lateinit var notaryStateMachines: Observable<Pair<NodeInfo, StateMachineUpdate>>
|
lateinit var notaryStateMachines: Observable<Pair<NodeInfo, StateMachineUpdate>>
|
||||||
|
|
||||||
@Before
|
override fun setup() = driver {
|
||||||
fun start() {
|
// Start Alice and 3 raft notaries
|
||||||
stopDriver = callSuspendResume { suspend ->
|
val clusterSize = 3
|
||||||
driver {
|
val testUser = User("test", "test", permissions = setOf(startFlowPermission<CashFlow>()))
|
||||||
// Start Alice and 3 raft notaries
|
val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser))
|
||||||
val clusterSize = 3
|
val notariesFuture = startNotaryCluster(
|
||||||
val testUser = User("test", "test", permissions = setOf(startFlowPermission<CashFlow>()))
|
"Notary",
|
||||||
val aliceFuture = startNode("Alice", rpcUsers = listOf(testUser))
|
rpcUsers = listOf(testUser),
|
||||||
val notariesFuture = startNotaryCluster(
|
clusterSize = clusterSize,
|
||||||
"Notary",
|
type = RaftValidatingNotaryService.type
|
||||||
rpcUsers = listOf(testUser),
|
)
|
||||||
clusterSize = clusterSize,
|
|
||||||
type = RaftValidatingNotaryService.type
|
|
||||||
)
|
|
||||||
|
|
||||||
alice = aliceFuture.get().nodeInfo
|
alice = aliceFuture.get().nodeInfo
|
||||||
val (notaryIdentity, notaryNodes) = notariesFuture.get()
|
val (notaryIdentity, notaryNodes) = notariesFuture.get()
|
||||||
raftNotaryIdentity = notaryIdentity
|
raftNotaryIdentity = notaryIdentity
|
||||||
notaries = notaryNodes
|
notaries = notaryNodes
|
||||||
|
|
||||||
assertEquals(notaries.size, clusterSize)
|
assertEquals(notaries.size, clusterSize)
|
||||||
assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size)
|
assertEquals(notaries.size, notaries.map { it.nodeInfo.legalIdentity }.toSet().size)
|
||||||
|
|
||||||
// Connect to Alice and the notaries
|
// Connect to Alice and the notaries
|
||||||
fun connectRpc(node: NodeInfo): CordaRPCOps {
|
fun connectRpc(node: NodeInfo): CordaRPCOps {
|
||||||
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL())
|
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(node.address), configureTestSSL())
|
||||||
client.start("test", "test")
|
client.start("test", "test")
|
||||||
return client.proxy()
|
return client.proxy()
|
||||||
}
|
|
||||||
aliceProxy = connectRpc(alice)
|
|
||||||
val notaryProxies = notaries.map { connectRpc(it.nodeInfo) }
|
|
||||||
notaryStateMachines = Observable.from(notaryProxies.map { proxy ->
|
|
||||||
proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) }
|
|
||||||
}).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed()
|
|
||||||
|
|
||||||
suspend()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
aliceProxy = connectRpc(alice)
|
||||||
|
val notaryProxies = notaries.map { connectRpc(it.nodeInfo) }
|
||||||
|
notaryStateMachines = Observable.from(notaryProxies.map { proxy ->
|
||||||
|
proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) }
|
||||||
|
}).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed()
|
||||||
|
|
||||||
@After
|
runTest()
|
||||||
fun stop() {
|
|
||||||
stopDriver()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -163,40 +163,6 @@ fun <A> driver(
|
|||||||
dsl = dsl
|
dsl = dsl
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
|
||||||
* Executes the passed in closure in a new thread, providing a function that suspends the closure, passing control back
|
|
||||||
* to the caller's context. The returned function may be used to then resume the closure.
|
|
||||||
*
|
|
||||||
* This can be used in conjunction with the driver to create @Before/@After blocks that start/shutdown the driver:
|
|
||||||
*
|
|
||||||
* val stopDriver = callSuspendResume { suspend ->
|
|
||||||
* driver(someOption = someValue) {
|
|
||||||
* .. initialise some test variables ..
|
|
||||||
* suspend()
|
|
||||||
* }
|
|
||||||
* }
|
|
||||||
* .. do tests ..
|
|
||||||
* stopDriver()
|
|
||||||
*/
|
|
||||||
fun <C> callSuspendResume(closure: (suspend: () -> Unit) -> C): () -> C {
|
|
||||||
val suspendLatch = CountDownLatch(1)
|
|
||||||
val resumeLatch = CountDownLatch(1)
|
|
||||||
val returnFuture = CompletableFuture<C>()
|
|
||||||
thread {
|
|
||||||
returnFuture.complete(
|
|
||||||
closure {
|
|
||||||
suspendLatch.countDown()
|
|
||||||
resumeLatch.await()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
suspendLatch.await()
|
|
||||||
return {
|
|
||||||
resumeLatch.countDown()
|
|
||||||
returnFuture.get()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This is a helper method to allow extending of the DSL, along the lines of
|
* This is a helper method to allow extending of the DSL, along the lines of
|
||||||
* interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface
|
* interface SomeOtherExposedDSLInterface : DriverDSLExposedInterface
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
package net.corda.node.driver
|
||||||
|
|
||||||
|
import org.junit.After
|
||||||
|
import org.junit.Before
|
||||||
|
import java.util.concurrent.CountDownLatch
|
||||||
|
import kotlin.concurrent.thread
|
||||||
|
|
||||||
|
abstract class DriverBasedTest {
|
||||||
|
private val stopDriver = CountDownLatch(1)
|
||||||
|
private var driverThread: Thread? = null
|
||||||
|
private lateinit var driverStarted: CountDownLatch
|
||||||
|
|
||||||
|
protected sealed class RunTestToken {
|
||||||
|
internal object Token : RunTestToken()
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract fun setup(): RunTestToken
|
||||||
|
|
||||||
|
protected fun DriverDSLExposedInterface.runTest(): RunTestToken {
|
||||||
|
driverStarted.countDown()
|
||||||
|
stopDriver.await()
|
||||||
|
return RunTestToken.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
@Before
|
||||||
|
fun start() {
|
||||||
|
driverStarted = CountDownLatch(1)
|
||||||
|
driverThread = thread {
|
||||||
|
setup()
|
||||||
|
}
|
||||||
|
driverStarted.await()
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
fun stop() {
|
||||||
|
stopDriver.countDown()
|
||||||
|
driverThread?.join()
|
||||||
|
}
|
||||||
|
}
|
@ -133,21 +133,21 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
|
|||||||
}
|
}
|
||||||
|
|
||||||
val addressesToCreateBridgesTo = HashSet<ArtemisPeerAddress>()
|
val addressesToCreateBridgesTo = HashSet<ArtemisPeerAddress>()
|
||||||
val addressesToRemoveBridgesTo = HashSet<ArtemisPeerAddress>()
|
val addressesToRemoveBridgesFrom = HashSet<ArtemisPeerAddress>()
|
||||||
when (change) {
|
when (change) {
|
||||||
is MapChange.Modified -> {
|
is MapChange.Modified -> {
|
||||||
addAddresses(change.node, addressesToCreateBridgesTo)
|
addAddresses(change.node, addressesToCreateBridgesTo)
|
||||||
addAddresses(change.previousNode, addressesToRemoveBridgesTo)
|
addAddresses(change.previousNode, addressesToRemoveBridgesFrom)
|
||||||
}
|
}
|
||||||
is MapChange.Removed -> {
|
is MapChange.Removed -> {
|
||||||
addAddresses(change.node, addressesToRemoveBridgesTo)
|
addAddresses(change.node, addressesToRemoveBridgesFrom)
|
||||||
}
|
}
|
||||||
is MapChange.Added -> {
|
is MapChange.Added -> {
|
||||||
addAddresses(change.node, addressesToCreateBridgesTo)
|
addAddresses(change.node, addressesToCreateBridgesTo)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
(addressesToRemoveBridgesTo - addressesToCreateBridgesTo).forEach {
|
(addressesToRemoveBridgesFrom - addressesToCreateBridgesTo).forEach {
|
||||||
maybeDestroyBridge(bridgeNameForAddress(it))
|
maybeDestroyBridge(bridgeNameForAddress(it))
|
||||||
}
|
}
|
||||||
addressesToCreateBridgesTo.forEach {
|
addressesToCreateBridgesTo.forEach {
|
||||||
|
@ -169,14 +169,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
|
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
|
||||||
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).second
|
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
|
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
|
||||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).second
|
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): Pair<Party, M> {
|
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage<M> {
|
||||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
|
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,8 +215,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): Pair<Party, M> {
|
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
|
||||||
fun getReceivedMessage(): Pair<Party, ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
|
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
|
||||||
|
|
||||||
val polledMessage = getReceivedMessage()
|
val polledMessage = getReceivedMessage()
|
||||||
val receivedMessage = if (polledMessage != null) {
|
val receivedMessage = if (polledMessage != null) {
|
||||||
@ -232,11 +232,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
|
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
|
||||||
}
|
}
|
||||||
|
|
||||||
if (receivedMessage.second is SessionEnd) {
|
if (receivedMessage.message is SessionEnd) {
|
||||||
openSessions.values.remove(receiveRequest.session)
|
openSessions.values.remove(receiveRequest.session)
|
||||||
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
|
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
|
||||||
} else if (receiveRequest.receiveType.isInstance(receivedMessage.second)) {
|
} else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
|
||||||
return Pair(receivedMessage.first, receiveRequest.receiveType.cast(receivedMessage.second))
|
return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message))
|
||||||
} else {
|
} else {
|
||||||
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
|
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
|
||||||
}
|
}
|
||||||
|
@ -244,7 +244,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
if (message is SessionEnd) {
|
if (message is SessionEnd) {
|
||||||
openSessions.remove(message.recipientSessionId)
|
openSessions.remove(message.recipientSessionId)
|
||||||
}
|
}
|
||||||
session.receivedMessages += Pair(otherParty, message)
|
session.receivedMessages += ReceivedSessionMessage(otherParty, message)
|
||||||
if (session.waitingForResponse) {
|
if (session.waitingForResponse) {
|
||||||
// We only want to resume once, so immediately reset the flag.
|
// We only want to resume once, so immediately reset the flag.
|
||||||
session.waitingForResponse = false
|
session.waitingForResponse = false
|
||||||
@ -277,7 +277,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
val psm = createFiber(flow)
|
val psm = createFiber(flow)
|
||||||
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
|
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
|
||||||
if (sessionInit.firstPayload != null) {
|
if (sessionInit.firstPayload != null) {
|
||||||
session.receivedMessages += Pair(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
||||||
}
|
}
|
||||||
openSessions[session.ourSessionId] = session
|
openSessions[session.ourSessionId] = session
|
||||||
psm.openSessions[Pair(flow, otherParty)] = session
|
psm.openSessions[Pair(flow, otherParty)] = session
|
||||||
@ -453,6 +453,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
serviceHub.networkService.send(sessionTopic, message, address)
|
serviceHub.networkService.send(sessionTopic, message, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data class ReceivedSessionMessage<out M : SessionMessage>(val sendingParty: Party, val message: M)
|
||||||
|
|
||||||
interface SessionMessage
|
interface SessionMessage
|
||||||
|
|
||||||
@ -509,7 +510,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
var state: FlowSessionState,
|
var state: FlowSessionState,
|
||||||
@Volatile var waitingForResponse: Boolean = false
|
@Volatile var waitingForResponse: Boolean = false
|
||||||
) {
|
) {
|
||||||
val receivedMessages = ConcurrentLinkedQueue<Pair<Party, ExistingSessionMessage>>()
|
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
|
||||||
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
|
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user