From 91f013c12704422f9eaedccfb8d0c6728cadc17f Mon Sep 17 00:00:00 2001 From: Maksymilian Pawlak Date: Tue, 8 May 2018 14:00:57 +0100 Subject: [PATCH] Flow framework test optimisation (#3031) * Flow framework test separation into classes, so the one which do not require nodes restart can execute faster. --- .../statemachine/FlowFrameworkTests.kt | 1296 +++++++++-------- 1 file changed, 717 insertions(+), 579 deletions(-) diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index e4512c01af..6aa94f993a 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -40,10 +40,7 @@ import net.corda.testing.node.internal.InternalMockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType -import org.junit.After -import org.junit.Before -import org.junit.Ignore -import org.junit.Test +import org.junit.* import rx.Notification import rx.Observable import java.time.Instant @@ -58,47 +55,50 @@ class FlowFrameworkTests { init { LogHelper.setLevel("+net.corda.flow") } - } - private lateinit var mockNet: InternalMockNetwork - private val receivedSessionMessages = ArrayList() - private lateinit var aliceNode: StartedNode - private lateinit var bobNode: StartedNode - private lateinit var notaryIdentity: Party - private lateinit var alice: Party - private lateinit var bob: Party + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: StartedNode + private lateinit var bobNode: StartedNode + private lateinit var alice: Party + private lateinit var bob: Party + private lateinit var notaryIdentity: Party + private val receivedSessionMessages = ArrayList() - @Before - fun start() { - mockNet = InternalMockNetwork( - cordappPackages = listOf("net.corda.finance.contracts", "net.corda.testing.contracts"), - servicePeerAllocationStrategy = RoundRobin() - ) - aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) - bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + @BeforeClass + @JvmStatic + fun beforeClass() { + mockNet = InternalMockNetwork( + cordappPackages = listOf("net.corda.finance.contracts", "net.corda.testing.contracts"), + servicePeerAllocationStrategy = RoundRobin() + ) - receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + + // Extract identities + alice = aliceNode.info.singleIdentity() + bob = bobNode.info.singleIdentity() + notaryIdentity = mockNet.defaultNotaryIdentity + + receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } + } + + private fun receivedSessionMessagesObservable(): Observable { + return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() + } + + @AfterClass @JvmStatic + fun afterClass() { + mockNet.stopNodes() + } - // Extract identities - alice = aliceNode.info.singleIdentity() - bob = bobNode.info.singleIdentity() - notaryIdentity = mockNet.defaultNotaryIdentity } @After fun cleanUp() { - mockNet.stopNodes() receivedSessionMessages.clear() } - @Test - fun `newly added flow is preserved on restart`() { - aliceNode.services.startFlow(NoOpFlow(nonTerminating = true)) - aliceNode.internals.acceptableLiveFiberCountOnStop = 1 - val restoredFlow = aliceNode.restartAndGetRestoredFlow() - assertThat(restoredFlow.flowStarted).isTrue() - } - @Test fun `flow can lazily use the serviceHub in its constructor`() { val flow = LazyServiceHubAccessFlow() @@ -106,19 +106,6 @@ class FlowFrameworkTests { assertThat(flow.lazyTime).isNotNull() } - class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor { - var thrown = false - @Suspendable - override fun executeAction(fiber: FlowFiber, action: Action) { - if (thrown) { - delegate.executeAction(fiber, action) - } else { - thrown = true - throw exception - } - } - } - @Test fun `exception while fiber suspended`() { bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } @@ -137,143 +124,6 @@ class FlowFrameworkTests { assertThat(fiber.state).isEqualTo(Strand.State.WAITING) } - @Test - fun `flow restarted just after receiving payload`() { - bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - aliceNode.services.startFlow(SendFlow("Hello", bob)) - - // We push through just enough messages to get only the payload sent - bobNode.pumpReceive() - bobNode.internals.disableDBCloseOnStop() - bobNode.internals.acceptableLiveFiberCountOnStop = 1 - bobNode.dispose() - mockNet.runNetwork() - val restoredFlow = bobNode.restartAndGetRestoredFlow() - assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") - } - - @Test - fun `flow loaded from checkpoint will respond to messages from before start`() { - aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } - bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow - // Make sure the add() has finished initial processing. - bobNode.internals.disableDBCloseOnStop() - bobNode.dispose() // kill receiver - val restoredFlow = bobNode.restartAndGetRestoredFlow() - assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") - } - - @Ignore("Some changes in startup order make this test's assumptions fail.") - @Test - fun `flow with send will resend on interrupted restart`() { - val payload = random63BitValue() - val payload2 = random63BitValue() - - var sentCount = 0 - mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } - val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val secondFlow = charlieNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } - mockNet.runNetwork() - val charlie = charlieNode.info.singleIdentity() - - // Kick off first send and receive - bobNode.services.startFlow(PingPongFlow(charlie, payload)) - bobNode.database.transaction { - assertEquals(1, bobNode.checkpointStorage.checkpoints().size) - } - // Make sure the add() has finished initial processing. - bobNode.internals.disableDBCloseOnStop() - // Restart node and thus reload the checkpoint and resend the message with same UUID - bobNode.dispose() - bobNode.database.transaction { - assertEquals(1, bobNode.checkpointStorage.checkpoints().size) // confirm checkpoint - bobNode.services.networkMapCache.clearNetworkMapCache() - } - val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id)) - bobNode.internals.manuallyCloseDB() - val (firstAgain, fut1) = node2b.getSingleFlow() - // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. - mockNet.runNetwork() - fut1.getOrThrow() - - val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } - // Check flows completed cleanly and didn't get out of phase - assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way - // can't give a precise value as every addMessageHandler re-runs the undelivered messages - assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") - node2b.database.transaction { - assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") - } - charlieNode.database.transaction { - assertEquals(0, charlieNode.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") - } - assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") - assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") - assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2") - assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2") - } - - @Test - fun `sending to multiple parties`() { - val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val charlie = charlieNode.info.singleIdentity() - bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - charlieNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - val payload = "Hello World" - aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) - mockNet.runNetwork() - bobNode.internals.acceptableLiveFiberCountOnStop = 1 - charlieNode.internals.acceptableLiveFiberCountOnStop = 1 - val bobFlow = bobNode.getSingleFlow().first - val charlieFlow = charlieNode.getSingleFlow().first - assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload) - assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload) - - assertSessionTransfers(bobNode, - aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - aliceNode sent normalEnd to bobNode - //There's no session end from the other flows as they're manually suspended - ) - - assertSessionTransfers(charlieNode, - aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode, - charlieNode sent sessionConfirm() to aliceNode, - aliceNode sent normalEnd to charlieNode - //There's no session end from the other flows as they're manually suspended - ) - } - - @Test - fun `receiving from multiple parties`() { - val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val charlie = charlieNode.info.singleIdentity() - val bobPayload = "Test 1" - val charliePayload = "Test 2" - bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } - charlieNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } - val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating() - aliceNode.services.startFlow(multiReceiveFlow) - aliceNode.internals.acceptableLiveFiberCountOnStop = 1 - mockNet.runNetwork() - assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload) - assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload) - - assertSessionTransfers(bobNode, - aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - bobNode sent sessionData(bobPayload) to aliceNode, - bobNode sent normalEnd to aliceNode - ) - - assertSessionTransfers(charlieNode, - aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode, - charlieNode sent sessionConfirm() to aliceNode, - charlieNode sent sessionData(charliePayload) to aliceNode, - charlieNode sent normalEnd to aliceNode - ) - } - @Test fun `both sides do a send as their first IO request`() { bobNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, 20L) } @@ -316,56 +166,6 @@ class FlowFrameworkTests { } } - @InitiatingFlow - private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party, - @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic() { - @Suspendable - override fun call() { - // Kick off the flow on the other side ... - val session = initiateFlow(otherParty) - session.send(1) - // ... then pause this one until it's received the session-end message from the other side - receivedOtherFlowEnd.acquire() - session.sendAndReceive(2) - } - } - - @Test - fun `non-FlowException thrown on other side`() { - val erroringFlowFuture = bobNode.registerFlowFactory(ReceiveFlow::class) { - ExceptionFlow { Exception("evil bug!") } - } - val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps } - - val receiveFlow = ReceiveFlow(bob) - val receiveFlowSteps = receiveFlow.progressSteps - val receiveFlowResult = aliceNode.services.startFlow(receiveFlow).resultFuture - - mockNet.runNetwork() - - erroringFlowFuture.getOrThrow() - val flowSteps = erroringFlowSteps.get() - assertThat(flowSteps).containsExactly( - Notification.createOnNext(ExceptionFlow.START_STEP), - Notification.createOnError(erroringFlowFuture.get().exceptionThrown) - ) - - val receiveFlowException = assertFailsWith(UnexpectedFlowEndException::class) { - receiveFlowResult.getOrThrow() - } - assertThat(receiveFlowException.message).doesNotContain("evil bug!") - assertThat(receiveFlowSteps.get()).containsExactly( - Notification.createOnNext(ReceiveFlow.START_STEP), - Notification.createOnError(receiveFlowException) - ) - - assertSessionTransfers( - aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - bobNode sent errorMessage() to aliceNode - ) - } - @Test fun `FlowException thrown on other side`() { val erroringFlow = bobNode.registerFlowFactory(ReceiveFlow::class) { @@ -402,52 +202,6 @@ class FlowFrameworkTests { assertThat((lastMessage.payload as ErrorSessionMessage).flowException!!.stackTrace).isEmpty() } - @Test - fun `FlowException only propagated to parent`() { - val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val charlie = charlieNode.info.singleIdentity() - - charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } - bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } - val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) - mockNet.runNetwork() - assertThatExceptionOfType(UnexpectedFlowEndException::class.java) - .isThrownBy { receivingFiber.resultFuture.getOrThrow() } - } - - @Test - fun `FlowException thrown and there is a 3rd unrelated party flow`() { - val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val charlie = charlieNode.info.singleIdentity() - - // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move - // onto Charlie which will throw the exception - val node2Fiber = bobNode - .registerFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } - .map { it.stateMachine } - charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } - - val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl - mockNet.runNetwork() - - // Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's - // not relevant to it) but it will end its session with it - assertThatExceptionOfType(MyFlowException::class.java).isThrownBy { - aliceFiber.resultFuture.getOrThrow() - } - val bobResultFuture = node2Fiber.getOrThrow().resultFuture - assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { - bobResultFuture.getOrThrow() - } - - assertSessionTransfers(bobNode, - aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - bobNode sent sessionData("Hello") to aliceNode, - aliceNode sent errorMessage() to bobNode - ) - } - private class ConditionalExceptionFlow(val otherPartySession: FlowSession, val sendPayload: Any) : FlowLogic() { @Suspendable override fun call() { @@ -598,11 +352,11 @@ class FlowFrameworkTests { @Test fun `unregistered flow`() { - val future = aliceNode.services.startFlow(SendFlow("Hello", bob)).resultFuture + val future = aliceNode.services.startFlow(NeverRegisteredFlow("Hello", bob)).resultFuture mockNet.runNetwork() assertThatExceptionOfType(UnexpectedFlowEndException::class.java) .isThrownBy { future.getOrThrow() } - .withMessageEndingWith("${SendFlow::class.java.name} is not registered") + .withMessageEndingWith("${NeverRegisteredFlow::class.java.name} is not registered") } @Test @@ -639,6 +393,349 @@ class FlowFrameworkTests { assertThat(result.getOrThrow()).isEqualTo("HelloHello") } + @Test + fun `non-FlowException thrown on other side`() { + val erroringFlowFuture = bobNode.registerFlowFactory(ReceiveFlow::class) { + ExceptionFlow { Exception("evil bug!") } + } + val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps } + + val receiveFlow = ReceiveFlow(bob) + val receiveFlowSteps = receiveFlow.progressSteps + val receiveFlowResult = aliceNode.services.startFlow(receiveFlow).resultFuture + + mockNet.runNetwork() + + erroringFlowFuture.getOrThrow() + val flowSteps = erroringFlowSteps.get() + assertThat(flowSteps).containsExactly( + Notification.createOnNext(ExceptionFlow.START_STEP), + Notification.createOnError(erroringFlowFuture.get().exceptionThrown) + ) + + val receiveFlowException = assertFailsWith(UnexpectedFlowEndException::class) { + receiveFlowResult.getOrThrow() + } + assertThat(receiveFlowException.message).doesNotContain("evil bug!") + assertThat(receiveFlowSteps.get()).containsExactly( + Notification.createOnNext(ReceiveFlow.START_STEP), + Notification.createOnError(receiveFlowException) + ) + + assertSessionTransfers( + aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + bobNode sent errorMessage() to aliceNode + ) + } + + //region Helpers + + private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) + + private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { + services.networkService.apply { + val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) + send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) + } + } + + private fun assertSessionTransfers(vararg expected: SessionTransfer) { + assertThat(receivedSessionMessages).containsExactly(*expected) + } + + //endregion Helpers +} + +class FlowFrameworkTripartyTests { + + companion object { + init { + LogHelper.setLevel("+net.corda.flow") + } + + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: StartedNode + private lateinit var bobNode: StartedNode + private lateinit var charlieNode: StartedNode + private lateinit var alice: Party + private lateinit var bob: Party + private lateinit var charlie: Party + private lateinit var notaryIdentity: Party + private val receivedSessionMessages = ArrayList() + + @BeforeClass + @JvmStatic + fun beforeClass() { + mockNet = InternalMockNetwork( + cordappPackages = listOf("net.corda.finance.contracts", "net.corda.testing.contracts"), + servicePeerAllocationStrategy = RoundRobin() + ) + + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) + + + // Extract identities + alice = aliceNode.info.singleIdentity() + bob = bobNode.info.singleIdentity() + charlie = charlieNode.info.singleIdentity() + notaryIdentity = mockNet.defaultNotaryIdentity + + receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } + } + + @AfterClass @JvmStatic + fun afterClass() { + mockNet.stopNodes() + } + + private fun receivedSessionMessagesObservable(): Observable { + return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() + } + + } + + @After + fun cleanUp() { + receivedSessionMessages.clear() + } + + + @Test + fun `sending to multiple parties`() { + bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } + charlieNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } + val payload = "Hello World" + aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) + mockNet.runNetwork() + bobNode.internals.acceptableLiveFiberCountOnStop = 1 + charlieNode.internals.acceptableLiveFiberCountOnStop = 1 + val bobFlow = bobNode.getSingleFlow().first + val charlieFlow = charlieNode.getSingleFlow().first + assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload) + assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload) + + assertSessionTransfers(bobNode, + aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + aliceNode sent normalEnd to bobNode + //There's no session end from the other flows as they're manually suspended + ) + + assertSessionTransfers(charlieNode, + aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode, + charlieNode sent sessionConfirm() to aliceNode, + aliceNode sent normalEnd to charlieNode + //There's no session end from the other flows as they're manually suspended + ) + } + + @Test + fun `receiving from multiple parties`() { + val bobPayload = "Test 1" + val charliePayload = "Test 2" + bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } + charlieNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } + val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating() + aliceNode.services.startFlow(multiReceiveFlow) + aliceNode.internals.acceptableLiveFiberCountOnStop = 1 + mockNet.runNetwork() + assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload) + assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload) + + assertSessionTransfers(bobNode, + aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + bobNode sent sessionData(bobPayload) to aliceNode, + bobNode sent normalEnd to aliceNode + ) + + assertSessionTransfers(charlieNode, + aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode, + charlieNode sent sessionConfirm() to aliceNode, + charlieNode sent sessionData(charliePayload) to aliceNode, + charlieNode sent normalEnd to aliceNode + ) + } + + @Test + fun `FlowException only propagated to parent`() { + charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } + bobNode.registerFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } + val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) + mockNet.runNetwork() + assertThatExceptionOfType(UnexpectedFlowEndException::class.java) + .isThrownBy { receivingFiber.resultFuture.getOrThrow() } + } + + @Test + fun `FlowException thrown and there is a 3rd unrelated party flow`() { + // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move + // onto Charlie which will throw the exception + val node2Fiber = bobNode + .registerFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } + .map { it.stateMachine } + charlieNode.registerFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } + + val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl + mockNet.runNetwork() + + // Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's + // not relevant to it) but it will end its session with it + assertThatExceptionOfType(MyFlowException::class.java).isThrownBy { + aliceFiber.resultFuture.getOrThrow() + } + val bobResultFuture = node2Fiber.getOrThrow().resultFuture + assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { + bobResultFuture.getOrThrow() + } + + assertSessionTransfers(bobNode, + aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + bobNode sent sessionData("Hello") to aliceNode, + aliceNode sent errorMessage() to bobNode + ) + } + + private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) + + private fun assertSessionTransfers(vararg expected: SessionTransfer) { + assertThat(receivedSessionMessages).containsExactly(*expected) + } + + private fun assertSessionTransfers(node: StartedNode, vararg expected: SessionTransfer): List { + val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress } + assertThat(actualForNode).containsExactly(*expected) + return actualForNode + } + +} + +class FlowFrameworkPersistenceTests { + companion object { + init { + LogHelper.setLevel("+net.corda.flow") + } + } + + private lateinit var mockNet: InternalMockNetwork + private val receivedSessionMessages = ArrayList() + private lateinit var aliceNode: StartedNode + private lateinit var bobNode: StartedNode + private lateinit var notaryIdentity: Party + private lateinit var alice: Party + private lateinit var bob: Party + + @Before + fun start() { + mockNet = InternalMockNetwork( + cordappPackages = listOf("net.corda.finance.contracts", "net.corda.testing.contracts"), + servicePeerAllocationStrategy = RoundRobin() + ) + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + + receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } + + // Extract identities + alice = aliceNode.info.singleIdentity() + bob = bobNode.info.singleIdentity() + notaryIdentity = mockNet.defaultNotaryIdentity + } + + @After + fun cleanUp() { + mockNet.stopNodes() + receivedSessionMessages.clear() + } + + @Test + fun `newly added flow is preserved on restart`() { + aliceNode.services.startFlow(NoOpFlow(nonTerminating = true)) + aliceNode.internals.acceptableLiveFiberCountOnStop = 1 + val restoredFlow = aliceNode.restartAndGetRestoredFlow() + assertThat(restoredFlow.flowStarted).isTrue() + } + + @Test + fun `flow restarted just after receiving payload`() { + bobNode.registerFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } + aliceNode.services.startFlow(SendFlow("Hello", bob)) + + // We push through just enough messages to get only the payload sent + bobNode.pumpReceive() + bobNode.internals.disableDBCloseOnStop() + bobNode.internals.acceptableLiveFiberCountOnStop = 1 + bobNode.dispose() + mockNet.runNetwork() + val restoredFlow = bobNode.restartAndGetRestoredFlow() + assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") + } + + @Test + fun `flow loaded from checkpoint will respond to messages from before start`() { + aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } + bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow + // Make sure the add() has finished initial processing. + bobNode.internals.disableDBCloseOnStop() + bobNode.dispose() // kill receiver + val restoredFlow = bobNode.restartAndGetRestoredFlow() + assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") + } + + @Ignore("Some changes in startup order make this test's assumptions fail.") + @Test + fun `flow with send will resend on interrupted restart`() { + val payload = random63BitValue() + val payload2 = random63BitValue() + + var sentCount = 0 + mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } + val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) + val secondFlow = charlieNode.registerFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } + mockNet.runNetwork() + val charlie = charlieNode.info.singleIdentity() + + // Kick off first send and receive + bobNode.services.startFlow(PingPongFlow(charlie, payload)) + bobNode.database.transaction { + assertEquals(1, bobNode.checkpointStorage.checkpoints().size) + } + // Make sure the add() has finished initial processing. + bobNode.internals.disableDBCloseOnStop() + // Restart node and thus reload the checkpoint and resend the message with same UUID + bobNode.dispose() + bobNode.database.transaction { + assertEquals(1, bobNode.checkpointStorage.checkpoints().size) // confirm checkpoint + bobNode.services.networkMapCache.clearNetworkMapCache() + } + val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id)) + bobNode.internals.manuallyCloseDB() + val (firstAgain, fut1) = node2b.getSingleFlow() + // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. + mockNet.runNetwork() + fut1.getOrThrow() + + val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } + // Check flows completed cleanly and didn't get out of phase + assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way + // can't give a precise value as every addMessageHandler re-runs the undelivered messages + assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") + node2b.database.transaction { + assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") + } + charlieNode.database.transaction { + assertEquals(0, charlieNode.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") + } + assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") + assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") + assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2") + assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2") + } + //////////////////////////////////////////////////////////////////////////////////////////////////////////// //region Helpers @@ -652,37 +749,6 @@ class FlowFrameworkTests { newNode.getSingleFlow

().first } - private inline fun > StartedNode<*>.getSingleFlow(): Pair> { - return smm.findStateMachines(P::class.java).single() - } - - private inline fun > StartedNode<*>.registerFlowFactory( - initiatingFlowClass: KClass>, - initiatedFlowVersion: Int = 1, - noinline flowFactory: (FlowSession) -> P): CordaFuture

{ - val observable = internalRegisterFlowFactory( - initiatingFlowClass.java, - InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory), - P::class.java, - track = true) - return observable.toFuture() - } - - private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { - return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) - } - private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) - private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) - private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) - private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) - - private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { - services.networkService.apply { - val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) - send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) - } - } - private fun assertSessionTransfers(vararg expected: SessionTransfer) { assertThat(receivedSessionMessages).containsExactly(*expected) } @@ -693,275 +759,347 @@ class FlowFrameworkTests { return actualForNode } - private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { - val isPayloadTransfer: Boolean get() = - message is ExistingSessionMessage && message.payload is DataSessionMessage || - message is InitialSessionMessage && message.firstPayload != null - override fun toString(): String = "$from sent $message to $to" - } - private fun receivedSessionMessagesObservable(): Observable { return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() } - private fun Observable.toSessionTransfers(): Observable { - return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map { - val from = it.sender.id - val message = it.messageData.deserialize() - SessionTransfer(from, sanitise(message), it.recipients) - } - } - - private fun sanitise(message: SessionMessage) = when (message) { - is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "") - is ExistingSessionMessage -> { - val payload = message.payload - message.copy( - recipientSessionId = SessionId(0), - payload = when (payload) { - is ConfirmSessionMessage -> payload.copy( - initiatedSessionId = SessionId(0), - initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "") - ) - is ErrorSessionMessage -> payload.copy( - errorId = 0 - ) - else -> payload - } - ) - } - } - - private infix fun StartedNode.sent(message: SessionMessage): Pair = Pair(internals.id, message) - private infix fun Pair.to(node: StartedNode<*>): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) - - private val FlowLogic<*>.progressSteps: CordaFuture>> - get() { - return progressTracker!!.changes - .ofType(Change.Position::class.java) - .map { it.newStep } - .materialize() - .toList() - .toFuture() - } - - private class LazyServiceHubAccessFlow : FlowLogic() { - val lazyTime: Instant by lazy { serviceHub.clock.instant() } - @Suspendable - override fun call() = Unit - } - - private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic() { - @Transient - var flowStarted = false - - @Suspendable - override fun call() { - flowStarted = true - if (nonTerminating) { - Fiber.park() - } - } - } - - @InitiatingFlow - private open class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic() { - init { - require(otherParties.isNotEmpty()) - } - - @Suspendable - override fun call(): FlowInfo { - val flowInfos = otherParties.map { - val session = initiateFlow(it) - session.send(payload) - session.getCounterpartyFlowInfo() - }.toList() - return flowInfos.first() - } - } - - private open class InitiatedSendFlow(val payload: Any, val otherPartySession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() = otherPartySession.send(payload) - } - - private interface CustomInterface - - private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) - - @InitiatingFlow - private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) - - @InitiatingFlow - private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic() { - object START_STEP : ProgressTracker.Step("Starting") - object RECEIVED_STEP : ProgressTracker.Step("Received") - - init { - require(otherParties.isNotEmpty()) - } - - override val progressTracker: ProgressTracker = ProgressTracker(START_STEP, RECEIVED_STEP) - private var nonTerminating: Boolean = false - @Transient - var receivedPayloads: List = emptyList() - - @Suspendable - override fun call() { - progressTracker.currentStep = START_STEP - receivedPayloads = otherParties.map { initiateFlow(it).receive().unwrap { it } } - progressTracker.currentStep = RECEIVED_STEP - if (nonTerminating) { - Fiber.park() - } - } - - fun nonTerminating(): ReceiveFlow { - nonTerminating = true - return this - } - } - - private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLogic() { - object START_STEP : ProgressTracker.Step("Starting") - object RECEIVED_STEP : ProgressTracker.Step("Received") - - override val progressTracker: ProgressTracker = ProgressTracker(START_STEP, RECEIVED_STEP) - private var nonTerminating: Boolean = false - @Transient - var receivedPayloads: List = emptyList() - - @Suspendable - override fun call() { - progressTracker.currentStep = START_STEP - receivedPayloads = listOf(otherPartySession.receive().unwrap { it }) - progressTracker.currentStep = RECEIVED_STEP - if (nonTerminating) { - Fiber.park() - } - } - - fun nonTerminating(): InitiatedReceiveFlow { - nonTerminating = true - return this - } - } - - @InitiatingFlow - private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val otherPartySession: FlowSession? = null) : FlowLogic() { - constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession) - - @Suspendable - override fun call(): Any = (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive(payload).unwrap { it } - } - - private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() = otherPartySession.send(payload) - } - - @InitiatingFlow - private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPartySession: FlowSession? = null) : FlowLogic() { - constructor(otherPartySession: FlowSession, payload: Long) : this(otherPartySession.counterparty, payload, otherPartySession) - - @Transient - var receivedPayload: Long? = null - @Transient - var receivedPayload2: Long? = null - - @Suspendable - override fun call() { - val session = otherPartySession ?: initiateFlow(otherParty) - receivedPayload = session.sendAndReceive(payload).unwrap { it } - receivedPayload2 = session.sendAndReceive(payload + 1).unwrap { it } - } - } - - private class ExceptionFlow(val exception: () -> E) : FlowLogic() { - object START_STEP : ProgressTracker.Step("Starting") - - override val progressTracker: ProgressTracker = ProgressTracker(START_STEP) - lateinit var exceptionThrown: E - - @Suspendable - override fun call(): Nothing { - progressTracker.currentStep = START_STEP - exceptionThrown = exception() - throw exceptionThrown - } - } - - private class MyFlowException(override val message: String) : FlowException() { - override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message - override fun hashCode(): Int = message.hashCode() - } - - private object WaitingFlows { - @InitiatingFlow - class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic() { - @Suspendable - override fun call(): SignedTransaction { - val otherPartySession = initiateFlow(otherParty) - otherPartySession.send(stx) - return waitForLedgerCommit(stx.id) - } - } - - class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic() { - @Suspendable - override fun call(): SignedTransaction { - val stx = otherPartySession.receive().unwrap { it } - if (throwException != null) throw throwException.invoke() - return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty))) - } - } - } - - @InitiatingFlow - private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic>>() { - @Suspendable - override fun call(): List> { - val otherPartySession = initiateFlow(otherParty) - otherPartySession.send(stx) - // hold onto reference here to force checkpoint of vaultService and thus - // prove it is registered as a tokenizableService in the node - val vaultQuerySvc = serviceHub.vaultService - waitForLedgerCommit(stx.id) - return vaultQuerySvc.queryBy().states - } - } - - @InitiatingFlow(version = 2) - private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic>() { - constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession) - - @Suspendable - override fun call(): Pair { - val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty) - val received = otherPartySession.receive().unwrap { it } - val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion - return Pair(received, otherFlowVersion) - } - } - - private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() { - val payload = otherPartySession.receive().unwrap { it } - subFlow(InlinedSendFlow(payload + payload, otherPartySession)) - } - } - - private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic() { - @Suspendable - override fun call() { - subFlow(SingleInlinedSubFlow(otherPartySession)) - } - } - - private data class NonSerialisableData(val a: Int) - private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException() - //endregion Helpers } + +private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) + +private inline fun > StartedNode<*>.getSingleFlow(): Pair> { + return smm.findStateMachines(P::class.java).single() +} + +private fun sanitise(message: SessionMessage) = when (message) { + is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "") + is ExistingSessionMessage -> { + val payload = message.payload + message.copy( + recipientSessionId = SessionId(0), + payload = when (payload) { + is ConfirmSessionMessage -> payload.copy( + initiatedSessionId = SessionId(0), + initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "") + ) + is ErrorSessionMessage -> payload.copy( + errorId = 0 + ) + else -> payload + } + ) + } +} + +private fun Observable.toSessionTransfers(): Observable { + return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map { + val from = it.sender.id + val message = it.messageData.deserialize() + SessionTransfer(from, sanitise(message), it.recipients) + } +} + +private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) + +private infix fun StartedNode.sent(message: SessionMessage): Pair = Pair(internals.id, message) +private infix fun Pair.to(node: StartedNode<*>): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) + +private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { + val isPayloadTransfer: Boolean get() = + message is ExistingSessionMessage && message.payload is DataSessionMessage || + message is InitialSessionMessage && message.firstPayload != null + override fun toString(): String = "$from sent $message to $to" +} + +private inline fun > StartedNode<*>.registerFlowFactory( + initiatingFlowClass: KClass>, + initiatedFlowVersion: Int = 1, + noinline flowFactory: (FlowSession) -> P): CordaFuture

{ + val observable = internalRegisterFlowFactory( + initiatingFlowClass.java, + InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory), + P::class.java, + track = true) + return observable.toFuture() +} + +private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { + return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) +} + +private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) + + +private val FlowLogic<*>.progressSteps: CordaFuture>> + get() { + return progressTracker!!.changes + .ofType(Change.Position::class.java) + .map { it.newStep } + .materialize() + .toList() + .toFuture() + } + +class ThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor { + var thrown = false + @Suspendable + override fun executeAction(fiber: FlowFiber, action: Action) { + if (thrown) { + delegate.executeAction(fiber, action) + } else { + thrown = true + throw exception + } + } +} + +@InitiatingFlow +private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party, + @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic() { + @Suspendable + override fun call() { + // Kick off the flow on the other side ... + val session = initiateFlow(otherParty) + session.send(1) + // ... then pause this one until it's received the session-end message from the other side + receivedOtherFlowEnd.acquire() + session.sendAndReceive(2) + } +} + +@InitiatingFlow +private open class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic() { + init { + require(otherParties.isNotEmpty()) + } + + @Suspendable + override fun call(): FlowInfo { + val flowInfos = otherParties.map { + val session = initiateFlow(it) + session.send(payload) + session.getCounterpartyFlowInfo() + }.toList() + return flowInfos.first() + } +} + +// we need brand new class for a flow to fail, so here it is +@InitiatingFlow +private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic() { + init { + require(otherParties.isNotEmpty()) + } + + @Suspendable + override fun call(): FlowInfo { + val flowInfos = otherParties.map { + val session = initiateFlow(it) + session.send(payload) + session.getCounterpartyFlowInfo() + }.toList() + return flowInfos.first() + } +} + +private object WaitingFlows { + @InitiatingFlow + class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic() { + @Suspendable + override fun call(): SignedTransaction { + val otherPartySession = initiateFlow(otherParty) + otherPartySession.send(stx) + return waitForLedgerCommit(stx.id) + } + } + + class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic() { + @Suspendable + override fun call(): SignedTransaction { + val stx = otherPartySession.receive().unwrap { it } + if (throwException != null) throw throwException.invoke() + return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty))) + } + } +} + +private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic() { + @Transient + var flowStarted = false + + @Suspendable + override fun call() { + flowStarted = true + if (nonTerminating) { + Fiber.park() + } + } +} + +private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLogic() { + object START_STEP : ProgressTracker.Step("Starting") + object RECEIVED_STEP : ProgressTracker.Step("Received") + + override val progressTracker: ProgressTracker = ProgressTracker(START_STEP, RECEIVED_STEP) + private var nonTerminating: Boolean = false + @Transient + var receivedPayloads: List = emptyList() + + @Suspendable + override fun call() { + progressTracker.currentStep = START_STEP + receivedPayloads = listOf(otherPartySession.receive().unwrap { it }) + progressTracker.currentStep = RECEIVED_STEP + if (nonTerminating) { + Fiber.park() + } + } + + fun nonTerminating(): InitiatedReceiveFlow { + nonTerminating = true + return this + } +} + +private class LazyServiceHubAccessFlow : FlowLogic() { + val lazyTime: Instant by lazy { serviceHub.clock.instant() } + @Suspendable + override fun call() = Unit +} + +private open class InitiatedSendFlow(val payload: Any, val otherPartySession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() = otherPartySession.send(payload) +} + +private interface CustomInterface + +private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) + +@InitiatingFlow +private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) + +@InitiatingFlow +private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic() { + object START_STEP : ProgressTracker.Step("Starting") + object RECEIVED_STEP : ProgressTracker.Step("Received") + + init { + require(otherParties.isNotEmpty()) + } + + override val progressTracker: ProgressTracker = ProgressTracker(START_STEP, RECEIVED_STEP) + private var nonTerminating: Boolean = false + @Transient + var receivedPayloads: List = emptyList() + + @Suspendable + override fun call() { + progressTracker.currentStep = START_STEP + receivedPayloads = otherParties.map { initiateFlow(it).receive().unwrap { it } } + progressTracker.currentStep = RECEIVED_STEP + if (nonTerminating) { + Fiber.park() + } + } + + fun nonTerminating(): ReceiveFlow { + nonTerminating = true + return this + } +} + +private class MyFlowException(override val message: String) : FlowException() { + override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message + override fun hashCode(): Int = message.hashCode() +} + +@InitiatingFlow +private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic>>() { + @Suspendable + override fun call(): List> { + val otherPartySession = initiateFlow(otherParty) + otherPartySession.send(stx) + // hold onto reference here to force checkpoint of vaultService and thus + // prove it is registered as a tokenizableService in the node + val vaultQuerySvc = serviceHub.vaultService + waitForLedgerCommit(stx.id) + return vaultQuerySvc.queryBy().states + } +} + +@InitiatingFlow(version = 2) +private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic>() { + constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession) + + @Suspendable + override fun call(): Pair { + val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty) + val received = otherPartySession.receive().unwrap { it } + val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion + return Pair(received, otherFlowVersion) + } +} + +private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + val payload = otherPartySession.receive().unwrap { it } + subFlow(InlinedSendFlow(payload + payload, otherPartySession)) + } +} + +private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + subFlow(SingleInlinedSubFlow(otherPartySession)) + } +} + +private data class NonSerialisableData(val a: Int) +private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException() + +@InitiatingFlow +private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val otherPartySession: FlowSession? = null) : FlowLogic() { + constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession) + + @Suspendable + override fun call(): Any = (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive(payload).unwrap { it } +} + +private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() = otherPartySession.send(payload) +} + +@InitiatingFlow +private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPartySession: FlowSession? = null) : FlowLogic() { + constructor(otherPartySession: FlowSession, payload: Long) : this(otherPartySession.counterparty, payload, otherPartySession) + + @Transient + var receivedPayload: Long? = null + @Transient + var receivedPayload2: Long? = null + + @Suspendable + override fun call() { + val session = otherPartySession ?: initiateFlow(otherParty) + receivedPayload = session.sendAndReceive(payload).unwrap { it } + receivedPayload2 = session.sendAndReceive(payload + 1).unwrap { it } + } +} + +private class ExceptionFlow(val exception: () -> E) : FlowLogic() { + object START_STEP : ProgressTracker.Step("Starting") + + override val progressTracker: ProgressTracker = ProgressTracker(START_STEP) + lateinit var exceptionThrown: E + + @Suspendable + override fun call(): Nothing { + progressTracker.currentStep = START_STEP + exceptionThrown = exception() + throw exceptionThrown + } +}