From 1c012f6403eb392778e16e94f11fd962b9f0039a Mon Sep 17 00:00:00 2001
From: Shams Asari <shams.asari@r3.com>
Date: Mon, 12 Nov 2018 18:38:47 +0000
Subject: [PATCH] Back porting clean up of FlowFrameworkTests.kt made in ENT
 (#4218)

---
 .../net/corda/node/internal/AbstractNode.kt   |  13 +-
 .../FlowFrameworkPersistenceTests.kt          | 166 +++++
 .../statemachine/FlowFrameworkTests.kt        | 577 +++++-------------
 .../FlowFrameworkTripartyTests.kt             | 178 ++++++
 4 files changed, 488 insertions(+), 446 deletions(-)
 create mode 100644 node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt
 create mode 100644 node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt

diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
index d48580b5e5..6ff5306a4a 100644
--- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
+++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
@@ -9,9 +9,9 @@ import net.corda.confidential.SwapIdentitiesHandler
 import net.corda.core.CordaException
 import net.corda.core.concurrent.CordaFuture
 import net.corda.core.context.InvocationContext
-import net.corda.core.crypto.internal.AliasPrivateKey
 import net.corda.core.crypto.DigitalSignature
 import net.corda.core.crypto.SecureHash
+import net.corda.core.crypto.internal.AliasPrivateKey
 import net.corda.core.crypto.newSecureRandom
 import net.corda.core.flows.*
 import net.corda.core.identity.AbstractParty
@@ -122,14 +122,14 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
                                cacheFactoryPrototype: BindableNamedCacheFactory,
                                protected val versionInfo: VersionInfo,
                                protected val flowManager: FlowManager,
-                               protected val serverThread: AffinityExecutor.ServiceAffinityExecutor,
-                               private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
+                               val serverThread: AffinityExecutor.ServiceAffinityExecutor,
+                               val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
 
     protected abstract val log: Logger
     @Suppress("LeakingThis")
     private var tokenizableServices: MutableList<Any>? = mutableListOf(platformClock, this)
 
-    protected val metricRegistry = MetricRegistry()
+    val metricRegistry = MetricRegistry()
     protected val cacheFactory = cacheFactoryPrototype.bindWithConfig(configuration).bindWithMetrics(metricRegistry).tokenize()
     val monitoringService = MonitoringService(metricRegistry).tokenize()
 
@@ -146,7 +146,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
         }
     }
 
-    protected val cordappLoader: CordappLoader = makeCordappLoader(configuration, versionInfo)
+    val cordappLoader: CordappLoader = makeCordappLoader(configuration, versionInfo)
     val schemaService = NodeSchemaService(cordappLoader.cordappSchemas).tokenize()
     val identityService = PersistentIdentityService(cacheFactory).tokenize()
     val database: CordaPersistence = createCordaPersistence(
@@ -777,7 +777,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
         // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because
         // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with
         // the identity key. But the infrastructure to make that easy isn't here yet.
-        return BasicHSMKeyManagementService(cacheFactory,identityService, database, cryptoService)
+        return BasicHSMKeyManagementService(cacheFactory, identityService, database, cryptoService)
     }
 
     open fun stop() {
@@ -1008,7 +1008,6 @@ class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogi
             private val _future = openFuture<FlowStateMachine<T>>()
             override val future: CordaFuture<FlowStateMachine<T>>
                 get() = _future
-
         }
         return startFlow(startFlowEvent)
     }
diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt
new file mode 100644
index 0000000000..8a9156af4a
--- /dev/null
+++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkPersistenceTests.kt
@@ -0,0 +1,166 @@
+package net.corda.node.services.statemachine
+
+import net.corda.core.crypto.random63BitValue
+import net.corda.core.flows.FlowLogic
+import net.corda.core.flows.registerCordappFlowFactory
+import net.corda.core.identity.Party
+import net.corda.core.utilities.getOrThrow
+import net.corda.node.services.persistence.checkpoints
+import net.corda.testing.core.ALICE_NAME
+import net.corda.testing.core.BOB_NAME
+import net.corda.testing.core.CHARLIE_NAME
+import net.corda.testing.core.singleIdentity
+import net.corda.testing.internal.LogHelper
+import net.corda.testing.node.InMemoryMessagingNetwork
+import net.corda.testing.node.internal.*
+import org.assertj.core.api.Assertions.assertThat
+import org.junit.After
+import org.junit.Before
+import org.junit.Ignore
+import org.junit.Test
+import rx.Observable
+import java.util.*
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+
+class FlowFrameworkPersistenceTests {
+    companion object {
+        init {
+            LogHelper.setLevel("+net.corda.flow")
+        }
+    }
+
+    private lateinit var mockNet: InternalMockNetwork
+    private val receivedSessionMessages = ArrayList<SessionTransfer>()
+    private lateinit var aliceNode: TestStartedNode
+    private lateinit var bobNode: TestStartedNode
+    private lateinit var notaryIdentity: Party
+    private lateinit var alice: Party
+    private lateinit var bob: Party
+    private lateinit var aliceFlowManager: MockNodeFlowManager
+    private lateinit var bobFlowManager: MockNodeFlowManager
+
+    @Before
+    fun start() {
+        mockNet = InternalMockNetwork(
+                cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
+                servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()
+        )
+        aliceFlowManager = MockNodeFlowManager()
+        bobFlowManager = MockNodeFlowManager()
+
+        aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager))
+        bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager))
+
+        receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
+
+        // Extract identities
+        alice = aliceNode.info.singleIdentity()
+        bob = bobNode.info.singleIdentity()
+        notaryIdentity = mockNet.defaultNotaryIdentity
+    }
+
+    @After
+    fun cleanUp() {
+        mockNet.stopNodes()
+        receivedSessionMessages.clear()
+    }
+
+    @Test
+    fun `newly added flow is preserved on restart`() {
+        aliceNode.services.startFlow(NoOpFlow(nonTerminating = true))
+        aliceNode.internals.acceptableLiveFiberCountOnStop = 1
+        val restoredFlow = aliceNode.restartAndGetRestoredFlow<NoOpFlow>()
+        assertThat(restoredFlow.flowStarted).isTrue()
+    }
+
+    @Test
+    fun `flow restarted just after receiving payload`() {
+        bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it)
+                .nonTerminating() }
+        aliceNode.services.startFlow(SendFlow("Hello", bob))
+
+        // We push through just enough messages to get only the payload sent
+        bobNode.pumpReceive()
+        bobNode.internals.disableDBCloseOnStop()
+        bobNode.internals.acceptableLiveFiberCountOnStop = 1
+        bobNode.dispose()
+        mockNet.runNetwork()
+        val restoredFlow = bobNode.restartAndGetRestoredFlow<InitiatedReceiveFlow>()
+        assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
+    }
+
+    @Test
+    fun `flow loaded from checkpoint will respond to messages from before start`() {
+        aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
+        bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
+        val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
+        assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
+    }
+
+    @Ignore("Some changes in startup order make this test's assumptions fail.")
+    @Test
+    fun `flow with send will resend on interrupted restart`() {
+        val payload = random63BitValue()
+        val payload2 = random63BitValue()
+
+        var sentCount = 0
+        mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
+        val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
+        val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
+        mockNet.runNetwork()
+        val charlie = charlieNode.info.singleIdentity()
+
+        // Kick off first send and receive
+        bobNode.services.startFlow(PingPongFlow(charlie, payload))
+        bobNode.database.transaction {
+            assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size)
+        }
+        // Make sure the add() has finished initial processing.
+        bobNode.internals.disableDBCloseOnStop()
+        // Restart node and thus reload the checkpoint and resend the message with same UUID
+        bobNode.dispose()
+        bobNode.database.transaction {
+            assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) // confirm checkpoint
+            bobNode.services.networkMapCache.clearNetworkMapCache()
+        }
+        val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id))
+        bobNode.internals.manuallyCloseDB()
+        val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
+        // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
+        mockNet.runNetwork()
+        fut1.getOrThrow()
+
+        val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
+        // Check flows completed cleanly and didn't get out of phase
+        assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way
+        // can't give a precise value as every addMessageHandler re-runs the undelivered messages
+        assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
+        node2b.database.transaction {
+            assertEquals(0, node2b.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
+        }
+        charlieNode.database.transaction {
+            assertEquals(0, charlieNode.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
+        }
+        assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
+        assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
+        assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
+        assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2")
+    }
+
+    ////////////////////////////////////////////////////////////////////////////////////////////////////////////
+    //region Helpers
+
+    private inline fun <reified P : FlowLogic<*>> TestStartedNode.restartAndGetRestoredFlow(): P {
+        val newNode = mockNet.restartNode(this)
+        newNode.internals.acceptableLiveFiberCountOnStop = 1
+        mockNet.runNetwork()
+        return newNode.getSingleFlow<P>().first
+    }
+
+    private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
+        return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
+    }
+
+    //endregion Helpers
+}
diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt
index 53ec7d9f2a..23b919eb66 100644
--- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt
+++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt
@@ -40,16 +40,13 @@ import org.assertj.core.api.Assertions.assertThatThrownBy
 import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
 import org.junit.After
 import org.junit.Before
-import org.junit.Ignore
 import org.junit.Test
 import rx.Notification
 import rx.Observable
 import java.time.Instant
 import java.util.*
 import kotlin.reflect.KClass
-import kotlin.test.assertEquals
 import kotlin.test.assertFailsWith
-import kotlin.test.assertTrue
 
 class FlowFrameworkTests {
     companion object {
@@ -449,320 +446,142 @@ class FlowFrameworkTests {
 
     private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
 
-    private fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
-        services.networkService.apply {
-            val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
-            send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
-        }
-    }
-
     private fun assertSessionTransfers(vararg expected: SessionTransfer) {
         assertThat(receivedSessionMessages).containsExactly(*expected)
     }
 
-    //endregion Helpers
-}
+    private val FlowLogic<*>.progressSteps: CordaFuture<List<Notification<ProgressTracker.Step>>>
+        get() {
+            return progressTracker!!.changes
+                    .ofType(Change.Position::class.java)
+                    .map { it.newStep }
+                    .materialize()
+                    .toList()
+                    .toFuture()
+        }
 
-class FlowFrameworkTripartyTests {
+    @InitiatingFlow
+    private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party,
+                                                          @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() {
+        @Suspendable
+        override fun call() {
+            // Kick off the flow on the other side ...
+            val session = initiateFlow(otherParty)
+            session.send(1)
+            // ... then pause this one until it's received the session-end message from the other side
+            receivedOtherFlowEnd.acquire()
+            session.sendAndReceive<Int>(2)
+        }
+    }
 
-    companion object {
+    // we need brand new class for a flow to fail, so here it is
+    @InitiatingFlow
+    private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
         init {
-            LogHelper.setLevel("+net.corda.flow")
+            require(otherParties.isNotEmpty())
         }
 
-        private lateinit var mockNet: InternalMockNetwork
-        private lateinit var aliceNode: TestStartedNode
-        private lateinit var bobNode: TestStartedNode
-        private lateinit var charlieNode: TestStartedNode
-        private lateinit var alice: Party
-        private lateinit var bob: Party
-        private lateinit var charlie: Party
-        private lateinit var notaryIdentity: Party
-        private val receivedSessionMessages = ArrayList<SessionTransfer>()
-    }
-
-    @Before
-    fun setUpGlobalMockNet() {
-        mockNet = InternalMockNetwork(
-                cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
-                servicePeerAllocationStrategy = RoundRobin()
-        )
-
-        aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
-        bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
-        charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
-
-
-        // Extract identities
-        alice = aliceNode.info.singleIdentity()
-        bob = bobNode.info.singleIdentity()
-        charlie = charlieNode.info.singleIdentity()
-        notaryIdentity = mockNet.defaultNotaryIdentity
-
-        receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
-    }
-
-    @After
-    fun cleanUp() {
-        mockNet.stopNodes()
-        receivedSessionMessages.clear()
-    }
-
-    private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
-        return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
-    }
-
-    @Test
-    fun `sending to multiple parties`() {
-        bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
-        charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
-        val payload = "Hello World"
-        aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
-        mockNet.runNetwork()
-        bobNode.internals.acceptableLiveFiberCountOnStop = 1
-        charlieNode.internals.acceptableLiveFiberCountOnStop = 1
-        val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
-        val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
-        assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
-        assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload)
-
-        assertSessionTransfers(bobNode,
-                aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode,
-                bobNode sent sessionConfirm() to aliceNode,
-                aliceNode sent normalEnd to bobNode
-                //There's no session end from the other flows as they're manually suspended
-        )
-
-        assertSessionTransfers(charlieNode,
-                aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode,
-                charlieNode sent sessionConfirm() to aliceNode,
-                aliceNode sent normalEnd to charlieNode
-                //There's no session end from the other flows as they're manually suspended
-        )
-    }
-
-    @Test
-    fun `receiving from multiple parties`() {
-        val bobPayload = "Test 1"
-        val charliePayload = "Test 2"
-        bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
-        charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
-        val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating()
-        aliceNode.services.startFlow(multiReceiveFlow)
-        aliceNode.internals.acceptableLiveFiberCountOnStop = 1
-        mockNet.runNetwork()
-        assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload)
-        assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload)
-
-        assertSessionTransfers(bobNode,
-                aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
-                bobNode sent sessionConfirm() to aliceNode,
-                bobNode sent sessionData(bobPayload) to aliceNode,
-                bobNode sent normalEnd to aliceNode
-        )
-
-        assertSessionTransfers(charlieNode,
-                aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode,
-                charlieNode sent sessionConfirm() to aliceNode,
-                charlieNode sent sessionData(charliePayload) to aliceNode,
-                charlieNode sent normalEnd to aliceNode
-        )
-    }
-
-    @Test
-    fun `FlowException only propagated to parent`() {
-        charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
-        bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
-        val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
-        mockNet.runNetwork()
-        assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
-                .isThrownBy { receivingFiber.resultFuture.getOrThrow() }
-    }
-
-    @Test
-    fun `FlowException thrown and there is a 3rd unrelated party flow`() {
-        // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move
-        // onto Charlie which will throw the exception
-        val node2Fiber = bobNode
-                .registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
-                .map { it.stateMachine }
-        charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
-
-        val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl
-        mockNet.runNetwork()
-
-        // Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's
-        // not relevant to it) but it will end its session with it
-        assertThatExceptionOfType(MyFlowException::class.java).isThrownBy {
-            aliceFiber.resultFuture.getOrThrow()
-        }
-        val bobResultFuture = node2Fiber.getOrThrow().resultFuture
-        assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
-            bobResultFuture.getOrThrow()
-        }
-
-        assertSessionTransfers(bobNode,
-                aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
-                bobNode sent sessionConfirm() to aliceNode,
-                bobNode sent sessionData("Hello") to aliceNode,
-                aliceNode sent errorMessage() to bobNode
-        )
-    }
-
-    private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
-
-    private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
-        val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }
-        assertThat(actualForNode).containsExactly(*expected)
-        return actualForNode
-    }
-
-}
-
-class FlowFrameworkPersistenceTests {
-    companion object {
-        init {
-            LogHelper.setLevel("+net.corda.flow")
+        @Suspendable
+        override fun call(): FlowInfo {
+            val flowInfos = otherParties.map {
+                val session = initiateFlow(it)
+                session.send(payload)
+                session.getCounterpartyFlowInfo()
+            }.toList()
+            return flowInfos.first()
         }
     }
 
-    private lateinit var mockNet: InternalMockNetwork
-    private val receivedSessionMessages = ArrayList<SessionTransfer>()
-    private lateinit var aliceNode: TestStartedNode
-    private lateinit var bobNode: TestStartedNode
-    private lateinit var notaryIdentity: Party
-    private lateinit var alice: Party
-    private lateinit var bob: Party
-    private lateinit var aliceFlowManager: MockNodeFlowManager
-    private lateinit var bobFlowManager: MockNodeFlowManager
-
-    @Before
-    fun start() {
-        mockNet = InternalMockNetwork(
-                cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
-                servicePeerAllocationStrategy = RoundRobin()
-        )
-        aliceFlowManager = MockNodeFlowManager()
-        bobFlowManager = MockNodeFlowManager()
-
-        aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager))
-        bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager))
-
-        receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
-
-        // Extract identities
-        alice = aliceNode.info.singleIdentity()
-        bob = bobNode.info.singleIdentity()
-        notaryIdentity = mockNet.defaultNotaryIdentity
-    }
-
-    @After
-    fun cleanUp() {
-        mockNet.stopNodes()
-        receivedSessionMessages.clear()
-    }
-
-    @Test
-    fun `newly added flow is preserved on restart`() {
-        aliceNode.services.startFlow(NoOpFlow(nonTerminating = true))
-        aliceNode.internals.acceptableLiveFiberCountOnStop = 1
-        val restoredFlow = aliceNode.restartAndGetRestoredFlow<NoOpFlow>()
-        assertThat(restoredFlow.flowStarted).isTrue()
-    }
-
-    @Test
-    fun `flow restarted just after receiving payload`() {
-        bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
-        aliceNode.services.startFlow(SendFlow("Hello", bob))
-
-        // We push through just enough messages to get only the payload sent
-        bobNode.pumpReceive()
-        bobNode.internals.disableDBCloseOnStop()
-        bobNode.internals.acceptableLiveFiberCountOnStop = 1
-        bobNode.dispose()
-        mockNet.runNetwork()
-        val restoredFlow = bobNode.restartAndGetRestoredFlow<InitiatedReceiveFlow>()
-        assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
-    }
-
-    @Test
-    fun `flow loaded from checkpoint will respond to messages from before start`() {
-        aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
-        bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
-        val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
-        assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
-    }
-
-    @Ignore("Some changes in startup order make this test's assumptions fail.")
-    @Test
-    fun `flow with send will resend on interrupted restart`() {
-        val payload = random63BitValue()
-        val payload2 = random63BitValue()
-
-        var sentCount = 0
-        mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
-        val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
-        val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
-        mockNet.runNetwork()
-        val charlie = charlieNode.info.singleIdentity()
-
-        // Kick off first send and receive
-        bobNode.services.startFlow(PingPongFlow(charlie, payload))
-        bobNode.database.transaction {
-            assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size)
+    private object WaitingFlows {
+        @InitiatingFlow
+        class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() {
+            @Suspendable
+            override fun call(): SignedTransaction {
+                val otherPartySession = initiateFlow(otherParty)
+                otherPartySession.send(stx)
+                return waitForLedgerCommit(stx.id)
+            }
         }
-        // Make sure the add() has finished initial processing.
-        bobNode.internals.disableDBCloseOnStop()
-        // Restart node and thus reload the checkpoint and resend the message with same UUID
-        bobNode.dispose()
-        bobNode.database.transaction {
-            assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) // confirm checkpoint
-            bobNode.services.networkMapCache.clearNetworkMapCache()
-        }
-        val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id))
-        bobNode.internals.manuallyCloseDB()
-        val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
-        // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
-        mockNet.runNetwork()
-        fut1.getOrThrow()
 
-        val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
-        // Check flows completed cleanly and didn't get out of phase
-        assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way
-        // can't give a precise value as every addMessageHandler re-runs the undelivered messages
-        assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
-        node2b.database.transaction {
-            assertEquals(0, node2b.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
+        class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() {
+            @Suspendable
+            override fun call(): SignedTransaction {
+                val stx = otherPartySession.receive<SignedTransaction>().unwrap { it }
+                if (throwException != null) throw throwException.invoke()
+                return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty)))
+            }
         }
-        charlieNode.database.transaction {
-            assertEquals(0, charlieNode.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
-        }
-        assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
-        assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
-        assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
-        assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2")
     }
 
-    ////////////////////////////////////////////////////////////////////////////////////////////////////////////
-    //region Helpers
-
-    private inline fun <reified P : FlowLogic<*>> TestStartedNode.restartAndGetRestoredFlow(): P {
-        val newNode = mockNet.restartNode(this)
-        newNode.internals.acceptableLiveFiberCountOnStop = 1
-        mockNet.runNetwork()
-        return newNode.getSingleFlow<P>().first
+    private class LazyServiceHubAccessFlow : FlowLogic<Unit>() {
+        val lazyTime: Instant by lazy { serviceHub.clock.instant() }
+        @Suspendable
+        override fun call() = Unit
     }
 
-    private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
-        return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
+    private interface CustomInterface
+
+    private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
+
+    @InitiatingFlow
+    private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
+
+    @InitiatingFlow
+    private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<List<StateAndRef<ContractState>>>() {
+        @Suspendable
+        override fun call(): List<StateAndRef<ContractState>> {
+            val otherPartySession = initiateFlow(otherParty)
+            otherPartySession.send(stx)
+            // hold onto reference here to force checkpoint of vaultService and thus
+            // prove it is registered as a tokenizableService in the node
+            val vaultQuerySvc = serviceHub.vaultService
+            waitForLedgerCommit(stx.id)
+            return vaultQuerySvc.queryBy<ContractState>().states
+        }
+    }
+
+    @InitiatingFlow(version = 2)
+    private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic<Pair<Any, Int>>() {
+        constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession)
+
+        @Suspendable
+        override fun call(): Pair<Any, Int> {
+            val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty)
+            val received = otherPartySession.receive<Any>().unwrap { it }
+            val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion
+            return Pair(received, otherFlowVersion)
+        }
+    }
+
+    private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
+        @Suspendable
+        override fun call() {
+            val payload = otherPartySession.receive<String>().unwrap { it }
+            subFlow(InlinedSendFlow(payload + payload, otherPartySession))
+        }
+    }
+
+    private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
+        @Suspendable
+        override fun call() {
+            subFlow(SingleInlinedSubFlow(otherPartySession))
+        }
+    }
+
+    private data class NonSerialisableData(val a: Int)
+    private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException()
+
+    private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
+        @Suspendable
+        override fun call() = otherPartySession.send(payload)
     }
 
     //endregion Helpers
 }
 
-private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
+internal fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
 
-private inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
+internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
     return smm.findStateMachines(P::class.java).single()
 }
 
@@ -786,7 +605,7 @@ private fun sanitise(message: SessionMessage) = when (message) {
     }
 }
 
-private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
+internal fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
     return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map {
         val from = it.sender.id
         val message = it.messageData.deserialize<SessionMessage>()
@@ -794,12 +613,19 @@ private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<Session
     }
 }
 
-private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
+internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
+    services.networkService.apply {
+        val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
+        send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
+    }
+}
 
-private infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
-private infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress)
+internal fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
 
-private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
+internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
+internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress)
+
+internal data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
     val isPayloadTransfer: Boolean
         get() =
             message is ExistingSessionMessage && message.payload is DataSessionMessage ||
@@ -808,40 +634,14 @@ private data class SessionTransfer(val from: Int, val message: SessionMessage, v
     override fun toString(): String = "$from sent $message to $to"
 }
 
-
-private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
+internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
     return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
 }
 
-private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
-
-
-private val FlowLogic<*>.progressSteps: CordaFuture<List<Notification<ProgressTracker.Step>>>
-    get() {
-        return progressTracker!!.changes
-                .ofType(Change.Position::class.java)
-                .map { it.newStep }
-                .materialize()
-                .toList()
-                .toFuture()
-    }
+internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
 
 @InitiatingFlow
-private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party,
-                                                      @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() {
-    @Suspendable
-    override fun call() {
-        // Kick off the flow on the other side ...
-        val session = initiateFlow(otherParty)
-        session.send(1)
-        // ... then pause this one until it's received the session-end message from the other side
-        receivedOtherFlowEnd.acquire()
-        session.sendAndReceive<Int>(2)
-    }
-}
-
-@InitiatingFlow
-private open class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
+internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
     init {
         require(otherParties.isNotEmpty())
     }
@@ -857,46 +657,7 @@ private open class SendFlow(val payload: Any, vararg val otherParties: Party) :
     }
 }
 
-// we need brand new class for a flow to fail, so here it is
-@InitiatingFlow
-private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
-    init {
-        require(otherParties.isNotEmpty())
-    }
-
-    @Suspendable
-    override fun call(): FlowInfo {
-        val flowInfos = otherParties.map {
-            val session = initiateFlow(it)
-            session.send(payload)
-            session.getCounterpartyFlowInfo()
-        }.toList()
-        return flowInfos.first()
-    }
-}
-
-private object WaitingFlows {
-    @InitiatingFlow
-    class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() {
-        @Suspendable
-        override fun call(): SignedTransaction {
-            val otherPartySession = initiateFlow(otherParty)
-            otherPartySession.send(stx)
-            return waitForLedgerCommit(stx.id)
-        }
-    }
-
-    class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() {
-        @Suspendable
-        override fun call(): SignedTransaction {
-            val stx = otherPartySession.receive<SignedTransaction>().unwrap { it }
-            if (throwException != null) throw throwException.invoke()
-            return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty)))
-        }
-    }
-}
-
-private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() {
+internal class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() {
     @Transient
     var flowStarted = false
 
@@ -909,7 +670,7 @@ private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>()
     }
 }
 
-private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
+internal class InitiatedReceiveFlow(private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
     object START_STEP : ProgressTracker.Step("Starting")
     object RECEIVED_STEP : ProgressTracker.Step("Received")
 
@@ -934,26 +695,13 @@ private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLog
     }
 }
 
-private class LazyServiceHubAccessFlow : FlowLogic<Unit>() {
-    val lazyTime: Instant by lazy { serviceHub.clock.instant() }
-    @Suspendable
-    override fun call() = Unit
-}
-
-private open class InitiatedSendFlow(val payload: Any, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
+internal open class InitiatedSendFlow(private val payload: Any, private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
     @Suspendable
     override fun call() = otherPartySession.send(payload)
 }
 
-private interface CustomInterface
-
-private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
-
 @InitiatingFlow
-private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
-
-@InitiatingFlow
-private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
+internal class ReceiveFlow(private vararg val otherParties: Party) : FlowLogic<Unit>() {
     object START_STEP : ProgressTracker.Step("Starting")
     object RECEIVED_STEP : ProgressTracker.Step("Received")
 
@@ -982,72 +730,23 @@ private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
     }
 }
 
-private class MyFlowException(override val message: String) : FlowException() {
+internal class MyFlowException(override val message: String) : FlowException() {
     override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message
     override fun hashCode(): Int = message.hashCode()
 }
 
 @InitiatingFlow
-private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<List<StateAndRef<ContractState>>>() {
-    @Suspendable
-    override fun call(): List<StateAndRef<ContractState>> {
-        val otherPartySession = initiateFlow(otherParty)
-        otherPartySession.send(stx)
-        // hold onto reference here to force checkpoint of vaultService and thus
-        // prove it is registered as a tokenizableService in the node
-        val vaultQuerySvc = serviceHub.vaultService
-        waitForLedgerCommit(stx.id)
-        return vaultQuerySvc.queryBy<ContractState>().states
-    }
-}
-
-@InitiatingFlow(version = 2)
-private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic<Pair<Any, Int>>() {
-    constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession)
-
-    @Suspendable
-    override fun call(): Pair<Any, Int> {
-        val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty)
-        val received = otherPartySession.receive<Any>().unwrap { it }
-        val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion
-        return Pair(received, otherFlowVersion)
-    }
-}
-
-private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
-    @Suspendable
-    override fun call() {
-        val payload = otherPartySession.receive<String>().unwrap { it }
-        subFlow(InlinedSendFlow(payload + payload, otherPartySession))
-    }
-}
-
-private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
-    @Suspendable
-    override fun call() {
-        subFlow(SingleInlinedSubFlow(otherPartySession))
-    }
-}
-
-private data class NonSerialisableData(val a: Int)
-private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException()
-
-@InitiatingFlow
-private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val otherPartySession: FlowSession? = null) : FlowLogic<Any>() {
+internal class SendAndReceiveFlow(private val otherParty: Party, private val payload: Any, private val otherPartySession: FlowSession? = null) : FlowLogic<Any>() {
     constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession)
 
     @Suspendable
-    override fun call(): Any = (otherPartySession
-            ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
-}
-
-private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
-    @Suspendable
-    override fun call() = otherPartySession.send(payload)
+    override fun call(): Any {
+        return (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
+    }
 }
 
 @InitiatingFlow
-private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPartySession: FlowSession? = null) : FlowLogic<Unit>() {
+internal class PingPongFlow(private val otherParty: Party, private val payload: Long, private val otherPartySession: FlowSession? = null) : FlowLogic<Unit>() {
     constructor(otherPartySession: FlowSession, payload: Long) : this(otherPartySession.counterparty, payload, otherPartySession)
 
     @Transient
@@ -1063,7 +762,7 @@ private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPa
     }
 }
 
-private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() {
+internal class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() {
     object START_STEP : ProgressTracker.Step("Starting")
 
     override val progressTracker: ProgressTracker = ProgressTracker(START_STEP)
@@ -1075,4 +774,4 @@ private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<N
         exceptionThrown = exception()
         throw exceptionThrown
     }
-}
\ No newline at end of file
+}
diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt
new file mode 100644
index 0000000000..eb7be9531f
--- /dev/null
+++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTripartyTests.kt
@@ -0,0 +1,178 @@
+package net.corda.node.services.statemachine
+
+import net.corda.core.flows.UnexpectedFlowEndException
+import net.corda.core.flows.registerCordappFlowFactory
+import net.corda.core.identity.Party
+import net.corda.core.internal.concurrent.map
+import net.corda.core.utilities.getOrThrow
+import net.corda.testing.core.ALICE_NAME
+import net.corda.testing.core.BOB_NAME
+import net.corda.testing.core.CHARLIE_NAME
+import net.corda.testing.core.singleIdentity
+import net.corda.testing.internal.LogHelper
+import net.corda.testing.node.InMemoryMessagingNetwork
+import net.corda.testing.node.internal.*
+import org.assertj.core.api.Assertions.assertThat
+import org.assertj.core.api.AssertionsForClassTypes
+import org.junit.After
+import org.junit.Before
+import org.junit.Test
+import rx.Observable
+import java.util.*
+
+class FlowFrameworkTripartyTests {
+    companion object {
+        init {
+            LogHelper.setLevel("+net.corda.flow")
+        }
+
+        private lateinit var mockNet: InternalMockNetwork
+        private lateinit var aliceNode: TestStartedNode
+        private lateinit var bobNode: TestStartedNode
+        private lateinit var charlieNode: TestStartedNode
+        private lateinit var alice: Party
+        private lateinit var bob: Party
+        private lateinit var charlie: Party
+        private lateinit var notaryIdentity: Party
+        private val receivedSessionMessages = ArrayList<SessionTransfer>()
+    }
+
+    @Before
+    fun setUpGlobalMockNet() {
+        mockNet = InternalMockNetwork(
+                cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
+                servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()
+        )
+
+        aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
+        bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
+        charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
+
+
+        // Extract identities
+        alice = aliceNode.info.singleIdentity()
+        bob = bobNode.info.singleIdentity()
+        charlie = charlieNode.info.singleIdentity()
+        notaryIdentity = mockNet.defaultNotaryIdentity
+
+        receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
+    }
+
+    @After
+    fun cleanUp() {
+        mockNet.stopNodes()
+        receivedSessionMessages.clear()
+    }
+
+    private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
+        return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
+    }
+
+    @Test
+    fun `sending to multiple parties`() {
+        bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it)
+                .nonTerminating() }
+        charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it)
+                .nonTerminating() }
+        val payload = "Hello World"
+        aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
+        mockNet.runNetwork()
+        bobNode.internals.acceptableLiveFiberCountOnStop = 1
+        charlieNode.internals.acceptableLiveFiberCountOnStop = 1
+        val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
+        val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
+        assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
+        assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload)
+
+        assertSessionTransfers(bobNode,
+                aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode,
+                bobNode sent sessionConfirm() to aliceNode,
+                aliceNode sent normalEnd to bobNode
+                //There's no session end from the other flows as they're manually suspended
+        )
+
+        assertSessionTransfers(charlieNode,
+                aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode,
+                charlieNode sent sessionConfirm() to aliceNode,
+                aliceNode sent normalEnd to charlieNode
+                //There's no session end from the other flows as they're manually suspended
+        )
+    }
+
+    @Test
+    fun `receiving from multiple parties`() {
+        val bobPayload = "Test 1"
+        val charliePayload = "Test 2"
+        bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
+        charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
+        val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating()
+        aliceNode.services.startFlow(multiReceiveFlow)
+        aliceNode.internals.acceptableLiveFiberCountOnStop = 1
+        mockNet.runNetwork()
+        assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload)
+        assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload)
+
+        assertSessionTransfers(bobNode,
+                aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
+                bobNode sent sessionConfirm() to aliceNode,
+                bobNode sent sessionData(bobPayload) to aliceNode,
+                bobNode sent normalEnd to aliceNode
+        )
+
+        assertSessionTransfers(charlieNode,
+                aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode,
+                charlieNode sent sessionConfirm() to aliceNode,
+                charlieNode sent sessionData(charliePayload) to aliceNode,
+                charlieNode sent normalEnd to aliceNode
+        )
+    }
+
+    @Test
+    fun `FlowException only propagated to parent`() {
+        charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
+        bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
+        val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
+        mockNet.runNetwork()
+        AssertionsForClassTypes.assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
+                .isThrownBy { receivingFiber.resultFuture.getOrThrow() }
+    }
+
+    @Test
+    fun `FlowException thrown and there is a 3rd unrelated party flow`() {
+        // Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move
+        // onto Charlie which will throw the exception
+        val node2Fiber = bobNode
+                .registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
+                .map { it.stateMachine }
+        charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
+
+        val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl
+        mockNet.runNetwork()
+
+        // Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's
+        // not relevant to it) but it will end its session with it
+        AssertionsForClassTypes.assertThatExceptionOfType(MyFlowException::class.java)
+                .isThrownBy {
+            aliceFiber.resultFuture.getOrThrow()
+        }
+        val bobResultFuture = node2Fiber.getOrThrow().resultFuture
+        AssertionsForClassTypes.assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
+            bobResultFuture.getOrThrow()
+        }
+
+        assertSessionTransfers(bobNode,
+                aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
+                bobNode sent sessionConfirm() to aliceNode,
+                bobNode sent sessionData("Hello") to aliceNode,
+                aliceNode sent errorMessage() to bobNode
+        )
+    }
+
+    private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
+
+    private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
+        val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }
+        assertThat(actualForNode).containsExactly(*expected)
+        return actualForNode
+    }
+}