Fixed bug in flow framework with regards to sendAndReceive and session-ends

This commit is contained in:
Shams Asari 2017-08-10 22:20:48 +01:00
parent 008301c4e8
commit 1124383c2a
2 changed files with 50 additions and 17 deletions

View File

@ -348,8 +348,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val polledMessage = pollForMessage() val polledMessage = pollForMessage()
return if (polledMessage != null) { return if (polledMessage != null) {
if (this is SendAndReceive) { if (this is SendAndReceive) {
// We've already received a message but we suspend so that the send can be performed // Since we've already received the message, we downgrade to a send only to get the payload out and not
suspend(this) // inadvertently block
suspend(SendOnly(session, message))
} }
polledMessage polledMessage
} else { } else {

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.concurrent.Semaphore
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DOLLARS
@ -62,7 +63,7 @@ class FlowFrameworkTests {
} }
private val mockNet = MockNetwork(servicePeerAllocationStrategy = RoundRobin()) private val mockNet = MockNetwork(servicePeerAllocationStrategy = RoundRobin())
private val sessionTransfers = ArrayList<SessionTransfer>() private val receivedSessionMessages = ArrayList<SessionTransfer>()
private lateinit var node1: MockNode private lateinit var node1: MockNode
private lateinit var node2: MockNode private lateinit var node2: MockNode
private lateinit var notary1: MockNode private lateinit var notary1: MockNode
@ -81,7 +82,7 @@ class FlowFrameworkTests {
notary1 = mockNet.createNotaryNode(networkMapAddress = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name) notary1 = mockNet.createNotaryNode(networkMapAddress = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name)
notary2 = mockNet.createNotaryNode(networkMapAddress = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name) notary2 = mockNet.createNotaryNode(networkMapAddress = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name)
mockNet.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it } receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
mockNet.runNetwork() mockNet.runNetwork()
// We don't create a network map, so manually handle registrations // We don't create a network map, so manually handle registrations
@ -96,7 +97,7 @@ class FlowFrameworkTests {
@After @After
fun cleanUp() { fun cleanUp() {
mockNet.stopNodes() mockNet.stopNodes()
sessionTransfers.clear() receivedSessionMessages.clear()
} }
@Test @Test
@ -228,7 +229,7 @@ class FlowFrameworkTests {
node2b.smm.executor.flush() node2b.smm.executor.flush()
fut1.getOrThrow() fut1.getOrThrow()
val receivedCount = sessionTransfers.count { it.isPayloadTransfer } val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
// Check flows completed cleanly and didn't get out of phase // Check flows completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way
// can't give a precise value as every addMessageHandler re-runs the undelivered messages // can't give a precise value as every addMessageHandler re-runs the undelivered messages
@ -319,7 +320,8 @@ class FlowFrameworkTests {
node2 sent sessionData(20L) to node1, node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2, node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1, node2 sent sessionData(21L) to node1,
node1 sent normalEnd to node2 node1 sent normalEnd to node2,
node2 sent normalEnd to node1
) )
} }
@ -344,7 +346,7 @@ class FlowFrameworkTests {
val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!) val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!)
assertThat(notary1Address).isInstanceOf(InMemoryMessagingNetwork.ServiceHandle::class.java) assertThat(notary1Address).isInstanceOf(InMemoryMessagingNetwork.ServiceHandle::class.java)
assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!)) assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!))
sessionTransfers.expectEvents(isStrict = false) { receivedSessionMessages.expectEvents(isStrict = false) {
sequence( sequence(
// First Pay // First Pay
expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java.name }) { expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java.name }) {
@ -390,6 +392,32 @@ class FlowFrameworkTests {
}.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive }.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive
} }
@Test
fun `receiving unexpected session end before entering sendAndReceive`() {
node2.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
val sessionEndReceived = Semaphore(0)
receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() }
val resultFuture = node1.services.startFlow(
WaitForOtherSideEndBeforeSendAndReceive(node2.info.legalIdentity, sessionEndReceived)).resultFuture
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
resultFuture.getOrThrow()
}
}
@InitiatingFlow
private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party,
@Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
// Kick off the flow on the other side ...
send(otherParty, 1)
// ... then pause this one until it's received the session-end message from the other side
receivedOtherFlowEnd.acquire()
sendAndReceive<Int>(otherParty, 2)
}
}
@Test @Test
fun `non-FlowException thrown on other side`() { fun `non-FlowException thrown on other side`() {
val erroringFlowFuture = node2.registerFlowFactory(ReceiveFlow::class) { val erroringFlowFuture = node2.registerFlowFactory(ReceiveFlow::class) {
@ -456,7 +484,7 @@ class FlowFrameworkTests {
node2 sent erroredEnd(erroringFlow.get().exceptionThrown) to node1 node2 sent erroredEnd(erroringFlow.get().exceptionThrown) to node1
) )
// Make sure the original stack trace isn't sent down the wire // Make sure the original stack trace isn't sent down the wire
assertThat((sessionTransfers.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty()
} }
@Test @Test
@ -631,7 +659,7 @@ class FlowFrameworkTests {
node2.registerFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { SendFlow("Old initiated", it) } node2.registerFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { SendFlow("Old initiated", it) }
val result = node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity)).resultFuture val result = node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
assertThat(sessionTransfers).startsWith( assertThat(receivedSessionMessages).startsWith(
node1 sent sessionInit(UpgradedFlow::class, flowVersion = 2) to node2, node1 sent sessionInit(UpgradedFlow::class, flowVersion = 2) to node2,
node2 sent sessionConfirm(flowVersion = 1) to node1 node2 sent sessionConfirm(flowVersion = 1) to node1
) )
@ -646,7 +674,7 @@ class FlowFrameworkTests {
val initiatingFlow = SendFlow("Old initiating", node2.info.legalIdentity) val initiatingFlow = SendFlow("Old initiating", node2.info.legalIdentity)
node1.services.startFlow(initiatingFlow) node1.services.startFlow(initiatingFlow)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(sessionTransfers).startsWith( assertThat(receivedSessionMessages).startsWith(
node1 sent sessionInit(SendFlow::class, flowVersion = 1, payload = "Old initiating") to node2, node1 sent sessionInit(SendFlow::class, flowVersion = 1, payload = "Old initiating") to node2,
node2 sent sessionConfirm(flowVersion = 2) to node1 node2 sent sessionConfirm(flowVersion = 2) to node1
) )
@ -666,8 +694,8 @@ class FlowFrameworkTests {
fun `unknown class in session init`() { fun `unknown class in session init`() {
node1.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), node2) node1.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), node2)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(sessionTransfers).hasSize(2) // Only the session-init and session-reject are expected assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = sessionTransfers.last().message as SessionReject val reject = receivedSessionMessages.last().message as SessionReject
assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class") assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class")
} }
@ -675,8 +703,8 @@ class FlowFrameworkTests {
fun `non-flow class in session init`() { fun `non-flow class in session init`() {
node1.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), node2) node1.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), node2)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(sessionTransfers).hasSize(2) // Only the session-init and session-reject are expected assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = sessionTransfers.last().message as SessionReject val reject = receivedSessionMessages.last().message as SessionReject
assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow") assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow")
} }
@ -743,11 +771,11 @@ class FlowFrameworkTests {
} }
private fun assertSessionTransfers(vararg expected: SessionTransfer) { private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(sessionTransfers).containsExactly(*expected) assertThat(receivedSessionMessages).containsExactly(*expected)
} }
private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer): List<SessionTransfer> { private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer): List<SessionTransfer> {
val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.network.myAddress } val actualForNode = receivedSessionMessages.filter { it.from == node.id || it.to == node.network.myAddress }
assertThat(actualForNode).containsExactly(*expected) assertThat(actualForNode).containsExactly(*expected)
return actualForNode return actualForNode
} }
@ -757,6 +785,10 @@ class FlowFrameworkTests {
override fun toString(): String = "$from sent $message to $to" override fun toString(): String = "$from sent $message to $to"
} }
private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> { private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map {
val from = it.sender.id val from = it.sender.id