Back porting clean up of FlowFrameworkTests.kt made in ENT (#4218)

This commit is contained in:
Shams Asari 2018-11-12 18:38:47 +00:00 committed by GitHub
parent 369f23e306
commit 1c012f6403
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 488 additions and 446 deletions

View File

@ -9,9 +9,9 @@ import net.corda.confidential.SwapIdentitiesHandler
import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.internal.AliasPrivateKey
import net.corda.core.crypto.DigitalSignature
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.internal.AliasPrivateKey
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.*
import net.corda.core.identity.AbstractParty
@ -122,14 +122,14 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
cacheFactoryPrototype: BindableNamedCacheFactory,
protected val versionInfo: VersionInfo,
protected val flowManager: FlowManager,
protected val serverThread: AffinityExecutor.ServiceAffinityExecutor,
private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
val serverThread: AffinityExecutor.ServiceAffinityExecutor,
val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
protected abstract val log: Logger
@Suppress("LeakingThis")
private var tokenizableServices: MutableList<Any>? = mutableListOf(platformClock, this)
protected val metricRegistry = MetricRegistry()
val metricRegistry = MetricRegistry()
protected val cacheFactory = cacheFactoryPrototype.bindWithConfig(configuration).bindWithMetrics(metricRegistry).tokenize()
val monitoringService = MonitoringService(metricRegistry).tokenize()
@ -146,7 +146,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
}
}
protected val cordappLoader: CordappLoader = makeCordappLoader(configuration, versionInfo)
val cordappLoader: CordappLoader = makeCordappLoader(configuration, versionInfo)
val schemaService = NodeSchemaService(cordappLoader.cordappSchemas).tokenize()
val identityService = PersistentIdentityService(cacheFactory).tokenize()
val database: CordaPersistence = createCordaPersistence(
@ -777,7 +777,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
// Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because
// the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with
// the identity key. But the infrastructure to make that easy isn't here yet.
return BasicHSMKeyManagementService(cacheFactory,identityService, database, cryptoService)
return BasicHSMKeyManagementService(cacheFactory, identityService, database, cryptoService)
}
open fun stop() {
@ -1008,7 +1008,6 @@ class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogi
private val _future = openFuture<FlowStateMachine<T>>()
override val future: CordaFuture<FlowStateMachine<T>>
get() = _future
}
return startFlow(startFlowEvent)
}

View File

@ -0,0 +1,166 @@
package net.corda.node.services.statemachine
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.registerCordappFlowFactory
import net.corda.core.identity.Party
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.persistence.checkpoints
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.singleIdentity
import net.corda.testing.internal.LogHelper
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.internal.*
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import rx.Observable
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class FlowFrameworkPersistenceTests {
companion object {
init {
LogHelper.setLevel("+net.corda.flow")
}
}
private lateinit var mockNet: InternalMockNetwork
private val receivedSessionMessages = ArrayList<SessionTransfer>()
private lateinit var aliceNode: TestStartedNode
private lateinit var bobNode: TestStartedNode
private lateinit var notaryIdentity: Party
private lateinit var alice: Party
private lateinit var bob: Party
private lateinit var aliceFlowManager: MockNodeFlowManager
private lateinit var bobFlowManager: MockNodeFlowManager
@Before
fun start() {
mockNet = InternalMockNetwork(
cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()
)
aliceFlowManager = MockNodeFlowManager()
bobFlowManager = MockNodeFlowManager()
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager))
receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
// Extract identities
alice = aliceNode.info.singleIdentity()
bob = bobNode.info.singleIdentity()
notaryIdentity = mockNet.defaultNotaryIdentity
}
@After
fun cleanUp() {
mockNet.stopNodes()
receivedSessionMessages.clear()
}
@Test
fun `newly added flow is preserved on restart`() {
aliceNode.services.startFlow(NoOpFlow(nonTerminating = true))
aliceNode.internals.acceptableLiveFiberCountOnStop = 1
val restoredFlow = aliceNode.restartAndGetRestoredFlow<NoOpFlow>()
assertThat(restoredFlow.flowStarted).isTrue()
}
@Test
fun `flow restarted just after receiving payload`() {
bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it)
.nonTerminating() }
aliceNode.services.startFlow(SendFlow("Hello", bob))
// We push through just enough messages to get only the payload sent
bobNode.pumpReceive()
bobNode.internals.disableDBCloseOnStop()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
bobNode.dispose()
mockNet.runNetwork()
val restoredFlow = bobNode.restartAndGetRestoredFlow<InitiatedReceiveFlow>()
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
}
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
}
@Ignore("Some changes in startup order make this test's assumptions fail.")
@Test
fun `flow with send will resend on interrupted restart`() {
val payload = random63BitValue()
val payload2 = random63BitValue()
var sentCount = 0
mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
mockNet.runNetwork()
val charlie = charlieNode.info.singleIdentity()
// Kick off first send and receive
bobNode.services.startFlow(PingPongFlow(charlie, payload))
bobNode.database.transaction {
assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size)
}
// Make sure the add() has finished initial processing.
bobNode.internals.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID
bobNode.dispose()
bobNode.database.transaction {
assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) // confirm checkpoint
bobNode.services.networkMapCache.clearNetworkMapCache()
}
val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id))
bobNode.internals.manuallyCloseDB()
val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
// Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
mockNet.runNetwork()
fut1.getOrThrow()
val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
// Check flows completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way
// can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
node2b.database.transaction {
assertEquals(0, node2b.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
}
charlieNode.database.transaction {
assertEquals(0, charlieNode.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
}
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2")
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers
private inline fun <reified P : FlowLogic<*>> TestStartedNode.restartAndGetRestoredFlow(): P {
val newNode = mockNet.restartNode(this)
newNode.internals.acceptableLiveFiberCountOnStop = 1
mockNet.runNetwork()
return newNode.getSingleFlow<P>().first
}
private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
}
//endregion Helpers
}

View File

@ -40,16 +40,13 @@ import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Ignore
import org.junit.Test
import rx.Notification
import rx.Observable
import java.time.Instant
import java.util.*
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
class FlowFrameworkTests {
companion object {
@ -449,320 +446,142 @@ class FlowFrameworkTests {
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
}
}
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(receivedSessionMessages).containsExactly(*expected)
}
//endregion Helpers
}
private val FlowLogic<*>.progressSteps: CordaFuture<List<Notification<ProgressTracker.Step>>>
get() {
return progressTracker!!.changes
.ofType(Change.Position::class.java)
.map { it.newStep }
.materialize()
.toList()
.toFuture()
}
class FlowFrameworkTripartyTests {
@InitiatingFlow
private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party,
@Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
// Kick off the flow on the other side ...
val session = initiateFlow(otherParty)
session.send(1)
// ... then pause this one until it's received the session-end message from the other side
receivedOtherFlowEnd.acquire()
session.sendAndReceive<Int>(2)
}
}
companion object {
// we need brand new class for a flow to fail, so here it is
@InitiatingFlow
private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
init {
LogHelper.setLevel("+net.corda.flow")
require(otherParties.isNotEmpty())
}
private lateinit var mockNet: InternalMockNetwork
private lateinit var aliceNode: TestStartedNode
private lateinit var bobNode: TestStartedNode
private lateinit var charlieNode: TestStartedNode
private lateinit var alice: Party
private lateinit var bob: Party
private lateinit var charlie: Party
private lateinit var notaryIdentity: Party
private val receivedSessionMessages = ArrayList<SessionTransfer>()
}
@Before
fun setUpGlobalMockNet() {
mockNet = InternalMockNetwork(
cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
servicePeerAllocationStrategy = RoundRobin()
)
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
// Extract identities
alice = aliceNode.info.singleIdentity()
bob = bobNode.info.singleIdentity()
charlie = charlieNode.info.singleIdentity()
notaryIdentity = mockNet.defaultNotaryIdentity
receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
}
@After
fun cleanUp() {
mockNet.stopNodes()
receivedSessionMessages.clear()
}
private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
}
@Test
fun `sending to multiple parties`() {
bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload)
assertSessionTransfers(bobNode,
aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
aliceNode sent normalEnd to bobNode
//There's no session end from the other flows as they're manually suspended
)
assertSessionTransfers(charlieNode,
aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode,
charlieNode sent sessionConfirm() to aliceNode,
aliceNode sent normalEnd to charlieNode
//There's no session end from the other flows as they're manually suspended
)
}
@Test
fun `receiving from multiple parties`() {
val bobPayload = "Test 1"
val charliePayload = "Test 2"
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating()
aliceNode.services.startFlow(multiReceiveFlow)
aliceNode.internals.acceptableLiveFiberCountOnStop = 1
mockNet.runNetwork()
assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload)
assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload)
assertSessionTransfers(bobNode,
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData(bobPayload) to aliceNode,
bobNode sent normalEnd to aliceNode
)
assertSessionTransfers(charlieNode,
aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode,
charlieNode sent sessionConfirm() to aliceNode,
charlieNode sent sessionData(charliePayload) to aliceNode,
charlieNode sent normalEnd to aliceNode
)
}
@Test
fun `FlowException only propagated to parent`() {
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
}
@Test
fun `FlowException thrown and there is a 3rd unrelated party flow`() {
// Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move
// onto Charlie which will throw the exception
val node2Fiber = bobNode
.registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.map { it.stateMachine }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl
mockNet.runNetwork()
// Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's
// not relevant to it) but it will end its session with it
assertThatExceptionOfType(MyFlowException::class.java).isThrownBy {
aliceFiber.resultFuture.getOrThrow()
}
val bobResultFuture = node2Fiber.getOrThrow().resultFuture
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
bobResultFuture.getOrThrow()
}
assertSessionTransfers(bobNode,
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData("Hello") to aliceNode,
aliceNode sent errorMessage() to bobNode
)
}
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }
assertThat(actualForNode).containsExactly(*expected)
return actualForNode
}
}
class FlowFrameworkPersistenceTests {
companion object {
init {
LogHelper.setLevel("+net.corda.flow")
@Suspendable
override fun call(): FlowInfo {
val flowInfos = otherParties.map {
val session = initiateFlow(it)
session.send(payload)
session.getCounterpartyFlowInfo()
}.toList()
return flowInfos.first()
}
}
private lateinit var mockNet: InternalMockNetwork
private val receivedSessionMessages = ArrayList<SessionTransfer>()
private lateinit var aliceNode: TestStartedNode
private lateinit var bobNode: TestStartedNode
private lateinit var notaryIdentity: Party
private lateinit var alice: Party
private lateinit var bob: Party
private lateinit var aliceFlowManager: MockNodeFlowManager
private lateinit var bobFlowManager: MockNodeFlowManager
@Before
fun start() {
mockNet = InternalMockNetwork(
cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
servicePeerAllocationStrategy = RoundRobin()
)
aliceFlowManager = MockNodeFlowManager()
bobFlowManager = MockNodeFlowManager()
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME, flowManager = aliceFlowManager))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME, flowManager = bobFlowManager))
receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
// Extract identities
alice = aliceNode.info.singleIdentity()
bob = bobNode.info.singleIdentity()
notaryIdentity = mockNet.defaultNotaryIdentity
}
@After
fun cleanUp() {
mockNet.stopNodes()
receivedSessionMessages.clear()
}
@Test
fun `newly added flow is preserved on restart`() {
aliceNode.services.startFlow(NoOpFlow(nonTerminating = true))
aliceNode.internals.acceptableLiveFiberCountOnStop = 1
val restoredFlow = aliceNode.restartAndGetRestoredFlow<NoOpFlow>()
assertThat(restoredFlow.flowStarted).isTrue()
}
@Test
fun `flow restarted just after receiving payload`() {
bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it).nonTerminating() }
aliceNode.services.startFlow(SendFlow("Hello", bob))
// We push through just enough messages to get only the payload sent
bobNode.pumpReceive()
bobNode.internals.disableDBCloseOnStop()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
bobNode.dispose()
mockNet.runNetwork()
val restoredFlow = bobNode.restartAndGetRestoredFlow<InitiatedReceiveFlow>()
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
}
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
aliceNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
}
@Ignore("Some changes in startup order make this test's assumptions fail.")
@Test
fun `flow with send will resend on interrupted restart`() {
val payload = random63BitValue()
val payload2 = random63BitValue()
var sentCount = 0
mockNet.messagingNetwork.sentMessages.toSessionTransfers().filter { it.isPayloadTransfer }.forEach { sentCount++ }
val charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
val secondFlow = charlieNode.registerCordappFlowFactory(PingPongFlow::class) { PingPongFlow(it, payload2) }
mockNet.runNetwork()
val charlie = charlieNode.info.singleIdentity()
// Kick off first send and receive
bobNode.services.startFlow(PingPongFlow(charlie, payload))
bobNode.database.transaction {
assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size)
private object WaitingFlows {
@InitiatingFlow
class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val otherPartySession = initiateFlow(otherParty)
otherPartySession.send(stx)
return waitForLedgerCommit(stx.id)
}
}
// Make sure the add() has finished initial processing.
bobNode.internals.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID
bobNode.dispose()
bobNode.database.transaction {
assertEquals(1, bobNode.internals.checkpointStorage.checkpoints().size) // confirm checkpoint
bobNode.services.networkMapCache.clearNetworkMapCache()
}
val node2b = mockNet.createNode(InternalMockNodeParameters(bobNode.internals.id))
bobNode.internals.manuallyCloseDB()
val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
// Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
mockNet.runNetwork()
fut1.getOrThrow()
val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
// Check flows completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way
// can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
node2b.database.transaction {
assertEquals(0, node2b.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val stx = otherPartySession.receive<SignedTransaction>().unwrap { it }
if (throwException != null) throw throwException.invoke()
return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty)))
}
}
charlieNode.database.transaction {
assertEquals(0, charlieNode.internals.checkpointStorage.checkpoints().size, "Checkpoints left after restored flow should have ended")
}
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondFlow.getOrThrow().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondFlow.getOrThrow().receivedPayload2, "Received payload does not match the expected second value on Node 2")
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
//region Helpers
private inline fun <reified P : FlowLogic<*>> TestStartedNode.restartAndGetRestoredFlow(): P {
val newNode = mockNet.restartNode(this)
newNode.internals.acceptableLiveFiberCountOnStop = 1
mockNet.runNetwork()
return newNode.getSingleFlow<P>().first
private class LazyServiceHubAccessFlow : FlowLogic<Unit>() {
val lazyTime: Instant by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() = Unit
}
private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
private interface CustomInterface
private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
@InitiatingFlow
private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
@InitiatingFlow
private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<List<StateAndRef<ContractState>>>() {
@Suspendable
override fun call(): List<StateAndRef<ContractState>> {
val otherPartySession = initiateFlow(otherParty)
otherPartySession.send(stx)
// hold onto reference here to force checkpoint of vaultService and thus
// prove it is registered as a tokenizableService in the node
val vaultQuerySvc = serviceHub.vaultService
waitForLedgerCommit(stx.id)
return vaultQuerySvc.queryBy<ContractState>().states
}
}
@InitiatingFlow(version = 2)
private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic<Pair<Any, Int>>() {
constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession)
@Suspendable
override fun call(): Pair<Any, Int> {
val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty)
val received = otherPartySession.receive<Any>().unwrap { it }
val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion
return Pair(received, otherFlowVersion)
}
}
private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val payload = otherPartySession.receive<String>().unwrap { it }
subFlow(InlinedSendFlow(payload + payload, otherPartySession))
}
}
private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
subFlow(SingleInlinedSubFlow(otherPartySession))
}
}
private data class NonSerialisableData(val a: Int)
private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException()
private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() = otherPartySession.send(payload)
}
//endregion Helpers
}
private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
internal fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
private inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
internal inline fun <reified P : FlowLogic<*>> TestStartedNode.getSingleFlow(): Pair<P, CordaFuture<*>> {
return smm.findStateMachines(P::class.java).single()
}
@ -786,7 +605,7 @@ private fun sanitise(message: SessionMessage) = when (message) {
}
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
internal fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.getMessage().topic == FlowMessagingImpl.sessionTopic }.map {
val from = it.sender.id
val message = it.messageData.deserialize<SessionMessage>()
@ -794,12 +613,19 @@ private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<Session
}
}
private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
internal fun TestStartedNode.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
val address = getAddressOfParty(PartyInfo.SingleNode(destination, emptyList()))
send(createMessage(FlowMessagingImpl.sessionTopic, message.serialize().bytes), address)
}
}
private infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
private infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress)
internal fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
internal infix fun TestStartedNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)
internal infix fun Pair<Int, SessionMessage>.to(node: TestStartedNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress)
internal data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
val isPayloadTransfer: Boolean
get() =
message is ExistingSessionMessage && message.payload is DataSessionMessage ||
@ -808,40 +634,14 @@ private data class SessionTransfer(val from: Int, val message: SessionMessage, v
override fun toString(): String = "$from sent $message to $to"
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
internal fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
private val FlowLogic<*>.progressSteps: CordaFuture<List<Notification<ProgressTracker.Step>>>
get() {
return progressTracker!!.changes
.ofType(Change.Position::class.java)
.map { it.newStep }
.materialize()
.toList()
.toFuture()
}
internal fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
@InitiatingFlow
private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party,
@Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
// Kick off the flow on the other side ...
val session = initiateFlow(otherParty)
session.send(1)
// ... then pause this one until it's received the session-end message from the other side
receivedOtherFlowEnd.acquire()
session.sendAndReceive<Int>(2)
}
}
@InitiatingFlow
private open class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
internal open class SendFlow(private val payload: Any, private vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
init {
require(otherParties.isNotEmpty())
}
@ -857,46 +657,7 @@ private open class SendFlow(val payload: Any, vararg val otherParties: Party) :
}
}
// we need brand new class for a flow to fail, so here it is
@InitiatingFlow
private open class NeverRegisteredFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<FlowInfo>() {
init {
require(otherParties.isNotEmpty())
}
@Suspendable
override fun call(): FlowInfo {
val flowInfos = otherParties.map {
val session = initiateFlow(it)
session.send(payload)
session.getCounterpartyFlowInfo()
}.toList()
return flowInfos.first()
}
}
private object WaitingFlows {
@InitiatingFlow
class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val otherPartySession = initiateFlow(otherParty)
otherPartySession.send(stx)
return waitForLedgerCommit(stx.id)
}
}
class Committer(val otherPartySession: FlowSession, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val stx = otherPartySession.receive<SignedTransaction>().unwrap { it }
if (throwException != null) throw throwException.invoke()
return subFlow(FinalityFlow(stx, setOf(otherPartySession.counterparty)))
}
}
}
private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() {
internal class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() {
@Transient
var flowStarted = false
@ -909,7 +670,7 @@ private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>()
}
}
private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
internal class InitiatedReceiveFlow(private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
object START_STEP : ProgressTracker.Step("Starting")
object RECEIVED_STEP : ProgressTracker.Step("Received")
@ -934,26 +695,13 @@ private class InitiatedReceiveFlow(val otherPartySession: FlowSession) : FlowLog
}
}
private class LazyServiceHubAccessFlow : FlowLogic<Unit>() {
val lazyTime: Instant by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() = Unit
}
private open class InitiatedSendFlow(val payload: Any, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
internal open class InitiatedSendFlow(private val payload: Any, private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() = otherPartySession.send(payload)
}
private interface CustomInterface
private class CustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
@InitiatingFlow
private class IncorrectCustomSendFlow(payload: String, otherParty: Party) : CustomInterface, SendFlow(payload, otherParty)
@InitiatingFlow
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
internal class ReceiveFlow(private vararg val otherParties: Party) : FlowLogic<Unit>() {
object START_STEP : ProgressTracker.Step("Starting")
object RECEIVED_STEP : ProgressTracker.Step("Received")
@ -982,72 +730,23 @@ private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
}
}
private class MyFlowException(override val message: String) : FlowException() {
internal class MyFlowException(override val message: String) : FlowException() {
override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message
override fun hashCode(): Int = message.hashCode()
}
@InitiatingFlow
private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<List<StateAndRef<ContractState>>>() {
@Suspendable
override fun call(): List<StateAndRef<ContractState>> {
val otherPartySession = initiateFlow(otherParty)
otherPartySession.send(stx)
// hold onto reference here to force checkpoint of vaultService and thus
// prove it is registered as a tokenizableService in the node
val vaultQuerySvc = serviceHub.vaultService
waitForLedgerCommit(stx.id)
return vaultQuerySvc.queryBy<ContractState>().states
}
}
@InitiatingFlow(version = 2)
private class UpgradedFlow(val otherParty: Party, val otherPartySession: FlowSession? = null) : FlowLogic<Pair<Any, Int>>() {
constructor(otherPartySession: FlowSession) : this(otherPartySession.counterparty, otherPartySession)
@Suspendable
override fun call(): Pair<Any, Int> {
val otherPartySession = this.otherPartySession ?: initiateFlow(otherParty)
val received = otherPartySession.receive<Any>().unwrap { it }
val otherFlowVersion = otherPartySession.getCounterpartyFlowInfo().flowVersion
return Pair(received, otherFlowVersion)
}
}
private class SingleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val payload = otherPartySession.receive<String>().unwrap { it }
subFlow(InlinedSendFlow(payload + payload, otherPartySession))
}
}
private class DoubleInlinedSubFlow(val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
subFlow(SingleInlinedSubFlow(otherPartySession))
}
}
private data class NonSerialisableData(val a: Int)
private class NonSerialisableFlowException(@Suppress("unused") val data: NonSerialisableData) : FlowException()
@InitiatingFlow
private class SendAndReceiveFlow(val otherParty: Party, val payload: Any, val otherPartySession: FlowSession? = null) : FlowLogic<Any>() {
internal class SendAndReceiveFlow(private val otherParty: Party, private val payload: Any, private val otherPartySession: FlowSession? = null) : FlowLogic<Any>() {
constructor(otherPartySession: FlowSession, payload: Any) : this(otherPartySession.counterparty, payload, otherPartySession)
@Suspendable
override fun call(): Any = (otherPartySession
?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
}
private class InlinedSendFlow(val payload: String, val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() = otherPartySession.send(payload)
override fun call(): Any {
return (otherPartySession ?: initiateFlow(otherParty)).sendAndReceive<Any>(payload).unwrap { it }
}
}
@InitiatingFlow
private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPartySession: FlowSession? = null) : FlowLogic<Unit>() {
internal class PingPongFlow(private val otherParty: Party, private val payload: Long, private val otherPartySession: FlowSession? = null) : FlowLogic<Unit>() {
constructor(otherPartySession: FlowSession, payload: Long) : this(otherPartySession.counterparty, payload, otherPartySession)
@Transient
@ -1063,7 +762,7 @@ private class PingPongFlow(val otherParty: Party, val payload: Long, val otherPa
}
}
private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() {
internal class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() {
object START_STEP : ProgressTracker.Step("Starting")
override val progressTracker: ProgressTracker = ProgressTracker(START_STEP)
@ -1075,4 +774,4 @@ private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<N
exceptionThrown = exception()
throw exceptionThrown
}
}
}

View File

@ -0,0 +1,178 @@
package net.corda.node.services.statemachine
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.flows.registerCordappFlowFactory
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.map
import net.corda.core.utilities.getOrThrow
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.CHARLIE_NAME
import net.corda.testing.core.singleIdentity
import net.corda.testing.internal.LogHelper
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.internal.*
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.AssertionsForClassTypes
import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.Observable
import java.util.*
class FlowFrameworkTripartyTests {
companion object {
init {
LogHelper.setLevel("+net.corda.flow")
}
private lateinit var mockNet: InternalMockNetwork
private lateinit var aliceNode: TestStartedNode
private lateinit var bobNode: TestStartedNode
private lateinit var charlieNode: TestStartedNode
private lateinit var alice: Party
private lateinit var bob: Party
private lateinit var charlie: Party
private lateinit var notaryIdentity: Party
private val receivedSessionMessages = ArrayList<SessionTransfer>()
}
@Before
fun setUpGlobalMockNet() {
mockNet = InternalMockNetwork(
cordappsForAllNodes = cordappsForPackages("net.corda.finance.contracts", "net.corda.testing.contracts"),
servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()
)
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
bobNode = mockNet.createNode(InternalMockNodeParameters(legalName = BOB_NAME))
charlieNode = mockNet.createNode(InternalMockNodeParameters(legalName = CHARLIE_NAME))
// Extract identities
alice = aliceNode.info.singleIdentity()
bob = bobNode.info.singleIdentity()
charlie = charlieNode.info.singleIdentity()
notaryIdentity = mockNet.defaultNotaryIdentity
receivedSessionMessagesObservable().forEach { receivedSessionMessages += it }
}
@After
fun cleanUp() {
mockNet.stopNodes()
receivedSessionMessages.clear()
}
private fun receivedSessionMessagesObservable(): Observable<SessionTransfer> {
return mockNet.messagingNetwork.receivedMessages.toSessionTransfers()
}
@Test
fun `sending to multiple parties`() {
bobNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it)
.nonTerminating() }
charlieNode.registerCordappFlowFactory(SendFlow::class) { InitiatedReceiveFlow(it)
.nonTerminating() }
val payload = "Hello World"
aliceNode.services.startFlow(SendFlow(payload, bob, charlie))
mockNet.runNetwork()
bobNode.internals.acceptableLiveFiberCountOnStop = 1
charlieNode.internals.acceptableLiveFiberCountOnStop = 1
val bobFlow = bobNode.getSingleFlow<InitiatedReceiveFlow>().first
val charlieFlow = charlieNode.getSingleFlow<InitiatedReceiveFlow>().first
assertThat(bobFlow.receivedPayloads[0]).isEqualTo(payload)
assertThat(charlieFlow.receivedPayloads[0]).isEqualTo(payload)
assertSessionTransfers(bobNode,
aliceNode sent sessionInit(SendFlow::class, payload = payload) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
aliceNode sent normalEnd to bobNode
//There's no session end from the other flows as they're manually suspended
)
assertSessionTransfers(charlieNode,
aliceNode sent sessionInit(SendFlow::class, payload = payload) to charlieNode,
charlieNode sent sessionConfirm() to aliceNode,
aliceNode sent normalEnd to charlieNode
//There's no session end from the other flows as they're manually suspended
)
}
@Test
fun `receiving from multiple parties`() {
val bobPayload = "Test 1"
val charliePayload = "Test 2"
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(bobPayload, it) }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { InitiatedSendFlow(charliePayload, it) }
val multiReceiveFlow = ReceiveFlow(bob, charlie).nonTerminating()
aliceNode.services.startFlow(multiReceiveFlow)
aliceNode.internals.acceptableLiveFiberCountOnStop = 1
mockNet.runNetwork()
assertThat(multiReceiveFlow.receivedPayloads[0]).isEqualTo(bobPayload)
assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(charliePayload)
assertSessionTransfers(bobNode,
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData(bobPayload) to aliceNode,
bobNode sent normalEnd to aliceNode
)
assertSessionTransfers(charlieNode,
aliceNode sent sessionInit(ReceiveFlow::class) to charlieNode,
charlieNode sent sessionConfirm() to aliceNode,
charlieNode sent sessionData(charliePayload) to aliceNode,
charlieNode sent normalEnd to aliceNode
)
}
@Test
fun `FlowException only propagated to parent`() {
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
bobNode.registerCordappFlowFactory(ReceiveFlow::class) { ReceiveFlow(charlie) }
val receivingFiber = aliceNode.services.startFlow(ReceiveFlow(bob))
mockNet.runNetwork()
AssertionsForClassTypes.assertThatExceptionOfType(UnexpectedFlowEndException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
}
@Test
fun `FlowException thrown and there is a 3rd unrelated party flow`() {
// Bob will send its payload and then block waiting for the receive from Alice. Meanwhile Alice will move
// onto Charlie which will throw the exception
val node2Fiber = bobNode
.registerCordappFlowFactory(ReceiveFlow::class) { SendAndReceiveFlow(it, "Hello") }
.map { it.stateMachine }
charlieNode.registerCordappFlowFactory(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Nothing useful") } }
val aliceFiber = aliceNode.services.startFlow(ReceiveFlow(bob, charlie)) as FlowStateMachineImpl
mockNet.runNetwork()
// Alice will terminate with the error it received from Charlie but it won't propagate that to Bob (as it's
// not relevant to it) but it will end its session with it
AssertionsForClassTypes.assertThatExceptionOfType(MyFlowException::class.java)
.isThrownBy {
aliceFiber.resultFuture.getOrThrow()
}
val bobResultFuture = node2Fiber.getOrThrow().resultFuture
AssertionsForClassTypes.assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
bobResultFuture.getOrThrow()
}
assertSessionTransfers(bobNode,
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData("Hello") to aliceNode,
aliceNode sent errorMessage() to bobNode
)
}
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private fun assertSessionTransfers(node: TestStartedNode, vararg expected: SessionTransfer): List<SessionTransfer> {
val actualForNode = receivedSessionMessages.filter { it.from == node.internals.id || it.to == node.network.myAddress }
assertThat(actualForNode).containsExactly(*expected)
return actualForNode
}
}