From 1c012f6403eb392778e16e94f11fd962b9f0039a Mon Sep 17 00:00:00 2001 From: Shams Asari <shams.asari@r3.com> Date: Mon, 12 Nov 2018 18:38:47 +0000 Subject: [PATCH] Back porting clean up of FlowFrameworkTests.kt made in ENT (#4218) --- .../net/corda/node/internal/AbstractNode.kt | 13 +- .../FlowFrameworkPersistenceTests.kt | 166 +++++ .../statemachine/FlowFrameworkTests.kt | 577 +++++------------- .../FlowFrameworkTripartyTests.kt | 178 ++++++ 4 files changed, 488 insertions(+), 446 deletions(-) create mode 100644 node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt create mode 100644 node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index d48580b5e5..6ff5306a4a 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -9,9 +9,9 @@ import net.corda.confidential.SwapIdentitiesHandler import net.corda.core.CordaException import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext -import net.corda.core.crypto.internal.AliasPrivateKey import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.internal.AliasPrivateKey import net.corda.core.crypto.newSecureRandom import net.corda.core.flows.* import net.corda.core.identity.AbstractParty @@ -122,14 +122,14 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration, cacheFactoryPrototype: BindableNamedCacheFactory, protected val versionInfo: VersionInfo, protected val flowManager: FlowManager, - protected val serverThread: AffinityExecutor.ServiceAffinityExecutor, - private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { + val serverThread: AffinityExecutor.ServiceAffinityExecutor, + val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { protected abstract val log: Logger @Suppress("LeakingThis") private var tokenizableServices: MutableList<Any>? = mutableListOf(platformClock, this) - protected val metricRegistry = MetricRegistry() + val metricRegistry = MetricRegistry() protected val cacheFactory = cacheFactoryPrototype.bindWithConfig(configuration).bindWithMetrics(metricRegistry).tokenize() val monitoringService = MonitoringService(metricRegistry).tokenize() @@ -146,7 +146,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration, } } - protected val cordappLoader: CordappLoader = makeCordappLoader(configuration, versionInfo) + val cordappLoader: CordappLoader = makeCordappLoader(configuration, versionInfo) val schemaService = NodeSchemaService(cordappLoader.cordappSchemas).tokenize() val identityService = PersistentIdentityService(cacheFactory).tokenize() val database: CordaPersistence = createCordaPersistence( @@ -777,7 +777,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration, // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with // the identity key. But the infrastructure to make that easy isn't here yet. - return BasicHSMKeyManagementService(cacheFactory,identityService, database, cryptoService) + return BasicHSMKeyManagementService(cacheFactory, identityService, database, cryptoService) } open fun stop() { @@ -1008,7 +1008,6 @@ class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogi private val _future = openFuture<FlowStateMachine<T>>() override val future: CordaFuture<FlowStateMachine<T>> get() = _future - } return startFlow(startFlowEvent) } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt new file mode 100644 index 0000000000..8a9156af4a --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt @@ -0,0 +1,166 @@ +package net.corda.node.services.statemachine + +import net.corda.core.crypto.random63BitValue +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.registerCordappFlowFactory +import net.corda.core.identity.Party +import net.corda.core.utilities.getOrThrow +import net.corda.node.services.persistence.checkpoints +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.CHARLIE_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.internal.LogHelper +import net.corda.testing.node.InMemoryMessagingNetwork +import net.corda.testing.node.internal.* +import org.assertj.core.api.Assertions.assertThat +import org.junit.After +import org.junit.Before +import org.junit.Ignore +import org.junit.Test +import rx.Observable +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class FlowFrameworkPersistenceTests { + companion object { + init { + LogHelper.setLevel("+net.corda.flow") + } + } + + private lateinit var mockNet: InternalMockNetwork + private val receivedSessionMessages = ArrayList<SessionTransfer>() + private lateinit var aliceNode: TestStartedNode + private lateinit var bobNode: TestStartedNode + private lateinit var notaryIdentity: Party + private lateinit var alice: Party + private lateinit var bob: Party + private lateinit var aliceFlowManager: MockNodeFlowManager + private lateinit var bobFlowManager: MockNodeFlowManager + + @Before + fun start() { + mockNet = InternalMockNetwork( + cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"), + servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin() + ) + aliceFlowManager = MockNodeFlowManager() + bobFlowManager = MockNodeFlowManager() + + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager)) + + receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } + + // Extract identities + alice = aliceNode.info.singleIdentity() + bob = bobNode.info.singleIdentity() + notaryIdentity = mockNet.defaultNotaryIdentity + } + + @After + fun cleanUp() { + mockNet.stopNodes() + receivedSessionMessages.clear() + } + + @Test + fun `newly added flow is preserved on restart`() { + aliceNode.services.startFlow(NoOpFlow(nonTerminating = true)) + aliceNode.internals.acceptableLiveFiberCountOnStop = 1 + val restoredFlow = aliceNode.restartAndGetRestoredFlow<NoOpFlow>() + assertThat(restoredFlow.flowStarted).isTrue() + } + + @Test + fun `flow restarted just after receiving payload`() { + bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) + .nonTerminating() } + aliceNode.services.startFlow(SendFlow("Hello", bob)) + + // We push through just enough messages to get only the payload sent + bobNode.pumpReceive() + bobNode.internals.disableDBCloseOnStop() + bobNode.internals.acceptableLiveFiberCountOnStop = 1 + bobNode.dispose() + mockNet.runNetwork() + val restoredFlow = bobNode.restartAndGetRestoredFlow<InitiatedReceiveFlow>() + assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") + } + + @Test + fun `flow loaded from checkpoint will respond to messages from before start`() { + aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } + bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow + val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>() + assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") + } + + @Ignore("Some changes in startup order make this test's assumptions fail.") + @Test + fun `flow with send will resend on interrupted restart`() { + val payload = random63BitValue() + val payload2 = random63BitValue() + + var sentCount = 0 + mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } + val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) + val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } + mockNet.runNetwork() + val charlie = charlieNode.info.singleIdentity() + + // Kick off first send and receive + bobNode.services.startFlow(PingPongFlow(charlie, payload)) + bobNode.database.transaction { + assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) + } + // Make sure the add() has finished initial processing. + bobNode.internals.disableDBCloseOnStop() + // Restart node and thus reload the checkpoint and resend the message with same UUID + bobNode.dispose() + bobNode.database.transaction { + assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) // confirm checkpoint + bobNode.services.networkMapCache.clearNetworkMapCache() + } + val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id)) + bobNode.internals.manuallyCloseDB() + val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>() + // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. + mockNet.runNetwork() + fut1.getOrThrow() + + val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } + // Check flows completed cleanly and didn't get out of phase + assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way + // can't give a precise value as every addMessageHandler re-runs the undelivered messages + assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") + node2b.database.transaction { + assertEquals(0, node2b.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") + } + charlieNode.database.transaction { + assertEquals(0, charlieNode.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") + } + assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") + assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") + assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2") + assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2") + } + + //////////////////////////////////////////////////////////////////////////////////////////////////////////// + //region Helpers + + private inline fun <reified P : FlowLogic<*>> TestStartedNode.restartAndGetRestoredFlow(): P { + val newNode = mockNet.restartNode(this) + newNode.internals.acceptableLiveFiberCountOnStop = 1 + mockNet.runNetwork() + return newNode.getSingleFlow<P>().first + } + + private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> { + return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() + } + + //endregion Helpers +} diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 53ec7d9f2a..23b919eb66 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -40,16 +40,13 @@ import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType import org.junit.After import org.junit.Before -import org.junit.Ignore import org.junit.Test import rx.Notification import rx.Observable import java.time.Instant import java.util.* import kotlin.reflect.KClass -import kotlin.test.assertEquals import kotlin.test.assertFailsWith -import kotlin.test.assertTrue class FlowFrameworkTests { companion object { @@ -449,320 +446,142 @@ class FlowFrameworkTests { private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) - private fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) { - services.networkService.apply { - val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) - send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) - } - } - private fun assertSessionTransfers(vararg expected: SessionTransfer) { assertThat(receivedSessionMessages).containsExactly(*expected) } - //endregion Helpers -} + private val FlowLogic<*>.progressSteps: CordaFuture<List<Notification<ProgressTracker.Step>>> + get() { + return progressTracker!!.changes + .ofType(Change.Position::class.java) + .map { it.newStep } + .materialize() + .toList() + .toFuture() + } -class FlowFrameworkTripartyTests { + @InitiatingFlow + private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party, + @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() { + @Suspendable + override fun call() { + // Kick off the flow on the other side ... + val session = initiateFlow(otherParty) + session.send(1) + // ... then pause this one until it's received the session-end message from the other side + receivedOtherFlowEnd.acquire() + session.sendAndReceive<Int>(2) + } + } - companion object { + // we need brand new class for a flow to fail, so here it is + @InitiatingFlow + private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() { init { - LogHelper.setLevel("+net.corda.flow") + require(otherParties.isNotEmpty()) } - private lateinit var mockNet: InternalMockNetwork - private lateinit var aliceNode: TestStartedNode - private lateinit var bobNode: TestStartedNode - private lateinit var charlieNode: TestStartedNode - private lateinit var alice: Party - private lateinit var bob: Party - private lateinit var charlie: Party - private lateinit var notaryIdentity: Party - private val receivedSessionMessages = ArrayList<SessionTransfer>() - } - - @Before - fun setUpGlobalMockNet() { - mockNet = InternalMockNetwork( - cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"), - servicePeerAllocationStrategy = RoundRobin() - ) - - aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) - bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) - charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - - - // Extract identities - alice = aliceNode.info.singleIdentity() - bob = bobNode.info.singleIdentity() - charlie = charlieNode.info.singleIdentity() - notaryIdentity = mockNet.defaultNotaryIdentity - - receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } - } - - @After - fun cleanUp() { - mockNet.stopNodes() - receivedSessionMessages.clear() - } - - private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> { - return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() - } - - @Test - fun `sending to multiple parties`() { - bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - val payload = "Hello World" - aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) - mockNet.runNetwork() - bobNode.internals.acceptableLiveFiberCountOnStop = 1 - charlieNode.internals.acceptableLiveFiberCountOnStop = 1 - val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first - val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first - assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload) - assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload) - - assertSessionTransfers(bobNode, - aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - aliceNode sent normalEnd to bobNode - //There's no session end from the other flows as they're manually suspended - ) - - assertSessionTransfers(charlieNode, - aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode, - charlieNode sent sessionConfirm() to aliceNode, - aliceNode sent normalEnd to charlieNode - //There's no session end from the other flows as they're manually suspended - ) - } - - @Test - fun `receiving from multiple parties`() { - val bobPayload = "Test 1" - val charliePayload = "Test 2" - bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } - charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } - val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating() - aliceNode.services.startFlow(multiReceiveFlow) - aliceNode.internals.acceptableLiveFiberCountOnStop = 1 - mockNet.runNetwork() - assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload) - assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload) - - assertSessionTransfers(bobNode, - aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - bobNode sent sessionData(bobPayload) to aliceNode, - bobNode sent normalEnd to aliceNode - ) - - assertSessionTransfers(charlieNode, - aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode, - charlieNode sent sessionConfirm() to aliceNode, - charlieNode sent sessionData(charliePayload) to aliceNode, - charlieNode sent normalEnd to aliceNode - ) - } - - @Test - fun `FlowException only propagated to parent`() { - charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } - bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } - val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) - mockNet.runNetwork() - assertThatExceptionOfType(UnexpectedFlowEndException::class.java) - .isThrownBy { receivingFiber.resultFuture.getOrThrow() } - } - - @Test - fun `FlowException thrown and there is a 3rd unrelated party flow`() { - // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move - // onto Charlie which will throw the exception - val node2Fiber = bobNode - .registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } - .map { it.stateMachine } - charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } - - val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl - mockNet.runNetwork() - - // Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's - // not relevant to it) but it will end its session with it - assertThatExceptionOfType(MyFlowException::class.java).isThrownBy { - aliceFiber.resultFuture.getOrThrow() - } - val bobResultFuture = node2Fiber.getOrThrow().resultFuture - assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { - bobResultFuture.getOrThrow() - } - - assertSessionTransfers(bobNode, - aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, - bobNode sent sessionConfirm() to aliceNode, - bobNode sent sessionData("Hello") to aliceNode, - aliceNode sent errorMessage() to bobNode - ) - } - - private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) - - private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> { - val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress } - assertThat(actualForNode).containsExactly(*expected) - return actualForNode - } - -} - -class FlowFrameworkPersistenceTests { - companion object { - init { - LogHelper.setLevel("+net.corda.flow") + @Suspendable + override fun call(): FlowInfo { + val flowInfos = otherParties.map { + val session = initiateFlow(it) + session.send(payload) + session.getCounterpartyFlowInfo() + }.toList() + return flowInfos.first() } } - private lateinit var mockNet: InternalMockNetwork - private val receivedSessionMessages = ArrayList<SessionTransfer>() - private lateinit var aliceNode: TestStartedNode - private lateinit var bobNode: TestStartedNode - private lateinit var notaryIdentity: Party - private lateinit var alice: Party - private lateinit var bob: Party - private lateinit var aliceFlowManager: MockNodeFlowManager - private lateinit var bobFlowManager: MockNodeFlowManager - - @Before - fun start() { - mockNet = InternalMockNetwork( - cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"), - servicePeerAllocationStrategy = RoundRobin() - ) - aliceFlowManager = MockNodeFlowManager() - bobFlowManager = MockNodeFlowManager() - - aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager)) - bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager)) - - receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } - - // Extract identities - alice = aliceNode.info.singleIdentity() - bob = bobNode.info.singleIdentity() - notaryIdentity = mockNet.defaultNotaryIdentity - } - - @After - fun cleanUp() { - mockNet.stopNodes() - receivedSessionMessages.clear() - } - - @Test - fun `newly added flow is preserved on restart`() { - aliceNode.services.startFlow(NoOpFlow(nonTerminating = true)) - aliceNode.internals.acceptableLiveFiberCountOnStop = 1 - val restoredFlow = aliceNode.restartAndGetRestoredFlow<NoOpFlow>() - assertThat(restoredFlow.flowStarted).isTrue() - } - - @Test - fun `flow restarted just after receiving payload`() { - bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() } - aliceNode.services.startFlow(SendFlow("Hello", bob)) - - // We push through just enough messages to get only the payload sent - bobNode.pumpReceive() - bobNode.internals.disableDBCloseOnStop() - bobNode.internals.acceptableLiveFiberCountOnStop = 1 - bobNode.dispose() - mockNet.runNetwork() - val restoredFlow = bobNode.restartAndGetRestoredFlow<InitiatedReceiveFlow>() - assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") - } - - @Test - fun `flow loaded from checkpoint will respond to messages from before start`() { - aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } - bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow - val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>() - assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello") - } - - @Ignore("Some changes in startup order make this test's assumptions fail.") - @Test - fun `flow with send will resend on interrupted restart`() { - val payload = random63BitValue() - val payload2 = random63BitValue() - - var sentCount = 0 - mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ } - val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) - val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) } - mockNet.runNetwork() - val charlie = charlieNode.info.singleIdentity() - - // Kick off first send and receive - bobNode.services.startFlow(PingPongFlow(charlie, payload)) - bobNode.database.transaction { - assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) + private object WaitingFlows { + @InitiatingFlow + class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() { + @Suspendable + override fun call(): SignedTransaction { + val otherPartySession = initiateFlow(otherParty) + otherPartySession.send(stx) + return waitForLedgerCommit(stx.id) + } } - // Make sure the add() has finished initial processing. - bobNode.internals.disableDBCloseOnStop() - // Restart node and thus reload the checkpoint and resend the message with same UUID - bobNode.dispose() - bobNode.database.transaction { - assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) // confirm checkpoint - bobNode.services.networkMapCache.clearNetworkMapCache() - } - val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id)) - bobNode.internals.manuallyCloseDB() - val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>() - // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. - mockNet.runNetwork() - fut1.getOrThrow() - val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } - // Check flows completed cleanly and didn't get out of phase - assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way - // can't give a precise value as every addMessageHandler re-runs the undelivered messages - assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") - node2b.database.transaction { - assertEquals(0, node2b.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") + class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() { + @Suspendable + override fun call(): SignedTransaction { + val stx = otherPartySession.receive<SignedTransaction>().unwrap { it } + if (throwException != null) throw throwException.invoke() + return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty))) + } } - charlieNode.database.transaction { - assertEquals(0, charlieNode.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended") - } - assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") - assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") - assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2") - assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2") } - //////////////////////////////////////////////////////////////////////////////////////////////////////////// - //region Helpers - - private inline fun <reified P : FlowLogic<*>> TestStartedNode.restartAndGetRestoredFlow(): P { - val newNode = mockNet.restartNode(this) - newNode.internals.acceptableLiveFiberCountOnStop = 1 - mockNet.runNetwork() - return newNode.getSingleFlow<P>().first + private class LazyServiceHubAccessFlow : FlowLogic<Unit>() { + val lazyTime: Instant by lazy { serviceHub.clock.instant() } + @Suspendable + override fun call() = Unit } - private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> { - return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() + private interface CustomInterface + + private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) + + @InitiatingFlow + private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) + + @InitiatingFlow + private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<List<StateAndRef<ContractState>>>() { + @Suspendable + override fun call(): List<StateAndRef<ContractState>> { + val otherPartySession = initiateFlow(otherParty) + otherPartySession.send(stx) + // hold onto reference here to force checkpoint of vaultService and thus + // prove it is registered as a tokenizableService in the node + val vaultQuerySvc = serviceHub.vaultService + waitForLedgerCommit(stx.id) + return vaultQuerySvc.queryBy<ContractState>().states + } + } + + @InitiatingFlow(version = 2) + private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic<Pair<Any, Int>>() { + constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession) + + @Suspendable + override fun call(): Pair<Any, Int> { + val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty) + val received = otherPartySession.receive<Any>().unwrap { it } + val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion + return Pair(received, otherFlowVersion) + } + } + + private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() { + @Suspendable + override fun call() { + val payload = otherPartySession.receive<String>().unwrap { it } + subFlow(InlinedSendFlow(payload + payload, otherPartySession)) + } + } + + private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() { + @Suspendable + override fun call() { + subFlow(SingleInlinedSubFlow(otherPartySession)) + } + } + + private data class NonSerialisableData(val a: Int) + private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException() + + private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() { + @Suspendable + override fun call() = otherPartySession.send(payload) } //endregion Helpers } -private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) +internal fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) -private inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> { +internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> { return smm.findStateMachines(P::class.java).single() } @@ -786,7 +605,7 @@ private fun sanitise(message: SessionMessage) = when (message) { } } -private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> { +internal fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> { return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map { val from = it.sender.id val message = it.messageData.deserialize<SessionMessage>() @@ -794,12 +613,19 @@ private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<Session } } -private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) +internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) { + services.networkService.apply { + val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList())) + send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address) + } +} -private infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message) -private infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) +internal fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) -private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { +internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message) +internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) + +internal data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { val isPayloadTransfer: Boolean get() = message is ExistingSessionMessage && message.payload is DataSessionMessage || @@ -808,40 +634,14 @@ private data class SessionTransfer(val from: Int, val message: SessionMessage, v override fun toString(): String = "$from sent $message to $to" } - -private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { +internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) } -private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) - - -private val FlowLogic<*>.progressSteps: CordaFuture<List<Notification<ProgressTracker.Step>>> - get() { - return progressTracker!!.changes - .ofType(Change.Position::class.java) - .map { it.newStep } - .materialize() - .toList() - .toFuture() - } +internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) @InitiatingFlow -private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party, - @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() { - @Suspendable - override fun call() { - // Kick off the flow on the other side ... - val session = initiateFlow(otherParty) - session.send(1) - // ... then pause this one until it's received the session-end message from the other side - receivedOtherFlowEnd.acquire() - session.sendAndReceive<Int>(2) - } -} - -@InitiatingFlow -private open class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() { +internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() { init { require(otherParties.isNotEmpty()) } @@ -857,46 +657,7 @@ private open class SendFlow(val payload: Any, vararg val otherParties: Party) : } } -// we need brand new class for a flow to fail, so here it is -@InitiatingFlow -private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() { - init { - require(otherParties.isNotEmpty()) - } - - @Suspendable - override fun call(): FlowInfo { - val flowInfos = otherParties.map { - val session = initiateFlow(it) - session.send(payload) - session.getCounterpartyFlowInfo() - }.toList() - return flowInfos.first() - } -} - -private object WaitingFlows { - @InitiatingFlow - class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() { - @Suspendable - override fun call(): SignedTransaction { - val otherPartySession = initiateFlow(otherParty) - otherPartySession.send(stx) - return waitForLedgerCommit(stx.id) - } - } - - class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() { - @Suspendable - override fun call(): SignedTransaction { - val stx = otherPartySession.receive<SignedTransaction>().unwrap { it } - if (throwException != null) throw throwException.invoke() - return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty))) - } - } -} - -private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() { +internal class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() { @Transient var flowStarted = false @@ -909,7 +670,7 @@ private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() } } -private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() { +internal class InitiatedReceiveFlow(private val otherPartySession: FlowSession) : FlowLogic<Unit>() { object START_STEP : ProgressTracker.Step("Starting") object RECEIVED_STEP : ProgressTracker.Step("Received") @@ -934,26 +695,13 @@ private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLog } } -private class LazyServiceHubAccessFlow : FlowLogic<Unit>() { - val lazyTime: Instant by lazy { serviceHub.clock.instant() } - @Suspendable - override fun call() = Unit -} - -private open class InitiatedSendFlow(val payload: Any, val otherPartySession: FlowSession) : FlowLogic<Unit>() { +internal open class InitiatedSendFlow(private val payload: Any, private val otherPartySession: FlowSession) : FlowLogic<Unit>() { @Suspendable override fun call() = otherPartySession.send(payload) } -private interface CustomInterface - -private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) - @InitiatingFlow -private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty) - -@InitiatingFlow -private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() { +internal class ReceiveFlow(private vararg val otherParties: Party) : FlowLogic<Unit>() { object START_STEP : ProgressTracker.Step("Starting") object RECEIVED_STEP : ProgressTracker.Step("Received") @@ -982,72 +730,23 @@ private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() { } } -private class MyFlowException(override val message: String) : FlowException() { +internal class MyFlowException(override val message: String) : FlowException() { override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message override fun hashCode(): Int = message.hashCode() } @InitiatingFlow -private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<List<StateAndRef<ContractState>>>() { - @Suspendable - override fun call(): List<StateAndRef<ContractState>> { - val otherPartySession = initiateFlow(otherParty) - otherPartySession.send(stx) - // hold onto reference here to force checkpoint of vaultService and thus - // prove it is registered as a tokenizableService in the node - val vaultQuerySvc = serviceHub.vaultService - waitForLedgerCommit(stx.id) - return vaultQuerySvc.queryBy<ContractState>().states - } -} - -@InitiatingFlow(version = 2) -private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic<Pair<Any, Int>>() { - constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession) - - @Suspendable - override fun call(): Pair<Any, Int> { - val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty) - val received = otherPartySession.receive<Any>().unwrap { it } - val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion - return Pair(received, otherFlowVersion) - } -} - -private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() { - @Suspendable - override fun call() { - val payload = otherPartySession.receive<String>().unwrap { it } - subFlow(InlinedSendFlow(payload + payload, otherPartySession)) - } -} - -private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() { - @Suspendable - override fun call() { - subFlow(SingleInlinedSubFlow(otherPartySession)) - } -} - -private data class NonSerialisableData(val a: Int) -private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException() - -@InitiatingFlow -private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val otherPartySession: FlowSession? = null) : FlowLogic<Any>() { +internal class SendAndReceiveFlow(private val otherParty: Party, private val payload: Any, private val otherPartySession: FlowSession? = null) : FlowLogic<Any>() { constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession) @Suspendable - override fun call(): Any = (otherPartySession - ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it } -} - -private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() { - @Suspendable - override fun call() = otherPartySession.send(payload) + override fun call(): Any { + return (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it } + } } @InitiatingFlow -private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPartySession: FlowSession? = null) : FlowLogic<Unit>() { +internal class PingPongFlow(private val otherParty: Party, private val payload: Long, private val otherPartySession: FlowSession? = null) : FlowLogic<Unit>() { constructor(otherPartySession: FlowSession, payload: Long) : this(otherPartySession.counterparty, payload, otherPartySession) @Transient @@ -1063,7 +762,7 @@ private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPa } } -private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() { +internal class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() { object START_STEP : ProgressTracker.Step("Starting") override val progressTracker: ProgressTracker = ProgressTracker(START_STEP) @@ -1075,4 +774,4 @@ private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<N exceptionThrown = exception() throw exceptionThrown } -} \ No newline at end of file +} diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt new file mode 100644 index 0000000000..eb7be9531f --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt @@ -0,0 +1,178 @@ +package net.corda.node.services.statemachine + +import net.corda.core.flows.UnexpectedFlowEndException +import net.corda.core.flows.registerCordappFlowFactory +import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.map +import net.corda.core.utilities.getOrThrow +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.CHARLIE_NAME +import net.corda.testing.core.singleIdentity +import net.corda.testing.internal.LogHelper +import net.corda.testing.node.InMemoryMessagingNetwork +import net.corda.testing.node.internal.* +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.AssertionsForClassTypes +import org.junit.After +import org.junit.Before +import org.junit.Test +import rx.Observable +import java.util.* + +class FlowFrameworkTripartyTests { + companion object { + init { + LogHelper.setLevel("+net.corda.flow") + } + + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: TestStartedNode + private lateinit var bobNode: TestStartedNode + private lateinit var charlieNode: TestStartedNode + private lateinit var alice: Party + private lateinit var bob: Party + private lateinit var charlie: Party + private lateinit var notaryIdentity: Party + private val receivedSessionMessages = ArrayList<SessionTransfer>() + } + + @Before + fun setUpGlobalMockNet() { + mockNet = InternalMockNetwork( + cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"), + servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin() + ) + + aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) + bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME)) + charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME)) + + + // Extract identities + alice = aliceNode.info.singleIdentity() + bob = bobNode.info.singleIdentity() + charlie = charlieNode.info.singleIdentity() + notaryIdentity = mockNet.defaultNotaryIdentity + + receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } + } + + @After + fun cleanUp() { + mockNet.stopNodes() + receivedSessionMessages.clear() + } + + private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> { + return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() + } + + @Test + fun `sending to multiple parties`() { + bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) + .nonTerminating() } + charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it) + .nonTerminating() } + val payload = "Hello World" + aliceNode.services.startFlow(SendFlow(payload, bob, charlie)) + mockNet.runNetwork() + bobNode.internals.acceptableLiveFiberCountOnStop = 1 + charlieNode.internals.acceptableLiveFiberCountOnStop = 1 + val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first + val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first + assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload) + assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload) + + assertSessionTransfers(bobNode, + aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + aliceNode sent normalEnd to bobNode + //There's no session end from the other flows as they're manually suspended + ) + + assertSessionTransfers(charlieNode, + aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode, + charlieNode sent sessionConfirm() to aliceNode, + aliceNode sent normalEnd to charlieNode + //There's no session end from the other flows as they're manually suspended + ) + } + + @Test + fun `receiving from multiple parties`() { + val bobPayload = "Test 1" + val charliePayload = "Test 2" + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) } + charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) } + val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating() + aliceNode.services.startFlow(multiReceiveFlow) + aliceNode.internals.acceptableLiveFiberCountOnStop = 1 + mockNet.runNetwork() + assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload) + assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload) + + assertSessionTransfers(bobNode, + aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + bobNode sent sessionData(bobPayload) to aliceNode, + bobNode sent normalEnd to aliceNode + ) + + assertSessionTransfers(charlieNode, + aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode, + charlieNode sent sessionConfirm() to aliceNode, + charlieNode sent sessionData(charliePayload) to aliceNode, + charlieNode sent normalEnd to aliceNode + ) + } + + @Test + fun `FlowException only propagated to parent`() { + charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } + bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) } + val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob)) + mockNet.runNetwork() + AssertionsForClassTypes.assertThatExceptionOfType(UnexpectedFlowEndException::class.java) + .isThrownBy { receivingFiber.resultFuture.getOrThrow() } + } + + @Test + fun `FlowException thrown and there is a 3rd unrelated party flow`() { + // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move + // onto Charlie which will throw the exception + val node2Fiber = bobNode + .registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") } + .map { it.stateMachine } + charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } } + + val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl + mockNet.runNetwork() + + // Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's + // not relevant to it) but it will end its session with it + AssertionsForClassTypes.assertThatExceptionOfType(MyFlowException::class.java) + .isThrownBy { + aliceFiber.resultFuture.getOrThrow() + } + val bobResultFuture = node2Fiber.getOrThrow().resultFuture + AssertionsForClassTypes.assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { + bobResultFuture.getOrThrow() + } + + assertSessionTransfers(bobNode, + aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, + bobNode sent sessionConfirm() to aliceNode, + bobNode sent sessionData("Hello") to aliceNode, + aliceNode sent errorMessage() to bobNode + ) + } + + private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) + + private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> { + val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress } + assertThat(actualForNode).containsExactly(*expected) + return actualForNode + } +}