From 29a101c3780a7bde1ccec60ea6425796899058a1 Mon Sep 17 00:00:00 2001 From: Michele Sollecito Date: Mon, 9 Oct 2017 13:46:37 +0100 Subject: [PATCH] [CORDA-683] Enable `receiveAll()` from Flows. --- .gitignore | 1 + .../kotlin/net/corda/core/flows/FlowLogic.kt | 48 +++++++- .../corda/core/internal/FlowStateMachine.kt | 11 +- .../net/corda/core/flows/FlowTestsUtils.kt | 115 ++++++++++++++++++ .../corda/core/flows/ReceiveAllFlowTests.kt | 87 +++++++++++++ docs/source/changelog.rst | 1 + .../services/statemachine/FlowIORequest.kt | 66 +++++++++- .../statemachine/FlowSessionInternal.kt | 2 + .../statemachine/FlowStateMachineImpl.kt | 34 +++++- .../statemachine/StateMachineManager.kt | 10 +- .../kotlin/net/corda/testing/node/MockNode.kt | 17 ++- 11 files changed, 377 insertions(+), 15 deletions(-) create mode 100644 core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt create mode 100644 core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt diff --git a/.gitignore b/.gitignore index 645707dcfb..ad05b9958f 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ lib/quasar.jar .idea/shelf .idea/dataSources .idea/markdown-navigator +.idea/runConfigurations /gradle-plugins/.idea/ # Include the -parameters compiler option by default in IntelliJ required for serialization. diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index fd2c144d9e..419c918753 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -6,6 +6,7 @@ import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.abbreviate +import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub @@ -177,6 +178,38 @@ abstract class FlowLogic { return stateMachine.receive(receiveType, otherParty, flowUsedForSessions) } + /** Suspends until a message has been received for each session in the specified [sessions]. + * + * Consider [receiveAll(receiveType: Class, sessions: List): List>] when the same type is expected from all sessions. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them. + */ + @Suspendable + open fun receiveAll(sessions: Map>): Map> { + return stateMachine.receiveAll(sessions, this) + } + + /** + * Suspends until a message has been received for each session in the specified [sessions]. + * + * Consider [sessions: Map>): Map>] when sessions are expected to receive different types. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. + */ + @Suspendable + open fun receiveAll(receiveType: Class, sessions: List): List> { + enforceNoDuplicates(sessions) + return castMapValuesToKnownType(receiveAll(associateSessionsToReceiveType(receiveType, sessions))) + } + /** * Queues the given [payload] for sending to the [otherParty] and continues without suspending. * @@ -231,7 +264,6 @@ abstract class FlowLogic { stateMachine.checkFlowPermission(permissionName, extraAuditData) } - /** * Flows can call this method to record application level flow audit events * @param eventType is a string representing the type of event. Each flow is given a distinct namespace for these names. @@ -334,6 +366,18 @@ abstract class FlowLogic { ours.setChildProgressTracker(ours.currentStep, theirs) } } + + private fun enforceNoDuplicates(sessions: List) { + require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." } + } + + private fun associateSessionsToReceiveType(receiveType: Class, sessions: List): Map> { + return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType }) + } + + private fun castMapValuesToKnownType(map: Map>): List> { + return map.values.map { uncheckedCast>(it) } + } } /** @@ -351,4 +395,4 @@ data class FlowInfo( * to deduplicate it from other releases of the same CorDapp, typically a version string. See the * [CorDapp JAR format](https://docs.corda.net/cordapp-build-systems.html#cordapp-jar-format) for more details. */ - val appName: String) + val appName: String) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index dc09d6da3f..261af49f02 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -30,20 +30,20 @@ interface FlowStateMachine { fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData @Suspendable - fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>): Unit + fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) @Suspendable fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction - fun checkFlowPermission(permissionName: String, extraAuditData: Map): Unit + fun checkFlowPermission(permissionName: String, extraAuditData: Map) - fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map): Unit + fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) @Suspendable fun flowStackSnapshot(flowClass: Class>): FlowStackSnapshot? @Suspendable - fun persistFlowStackSnapshot(flowClass: Class>): Unit + fun persistFlowStackSnapshot(flowClass: Class>) val serviceHub: ServiceHub val logger: Logger @@ -51,4 +51,7 @@ interface FlowStateMachine { val resultFuture: CordaFuture val flowInitiator: FlowInitiator val ourIdentityAndCert: PartyAndCertificate + + @Suspendable + fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> } diff --git a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt new file mode 100644 index 0000000000..6b6f0492b3 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt @@ -0,0 +1,115 @@ +package net.corda.core.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.utilities.UntrustworthyData +import net.corda.core.utilities.unwrap +import net.corda.node.internal.InitiatedFlowFactory +import net.corda.node.internal.StartedNode +import kotlin.reflect.KClass + +/** + * Allows to simplify writing flows that simply rend a message back to an initiating flow. + */ +class Answer(session: FlowSession, override val answer: R, closure: (result: R) -> Unit = {}) : SimpleAnswer(session, closure) + +/** + * Allows to simplify writing flows that simply rend a message back to an initiating flow. + */ +abstract class SimpleAnswer(private val session: FlowSession, private val closure: (result: R) -> Unit = {}) : FlowLogic() { + @Suspendable + override fun call() { + val tmp = answer + closure(tmp) + session.send(tmp) + } + + protected abstract val answer: R +} + +/** + * A flow that does not do anything when triggered. + */ +class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic() { + @Suspendable + override fun call() = closure() +} + +/** + * Allows to register a flow of type [R] against an initiating flow of type [I]. + */ +inline fun , reified R : FlowLogic<*>> StartedNode<*>.registerInitiatedFlow(initiatingFlowType: KClass, crossinline construct: (session: FlowSession) -> R) { + internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true) +} + +/** + * Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R]. + */ +inline fun , reified R : Any> StartedNode<*>.registerAnswer(initiatingFlowType: KClass, value: R) { + internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true) +} + +/** + * Extracts data from a [Map[FlowSession, UntrustworthyData]] without performing checks and casting to [R]. + */ +@Suppress("UNCHECKED_CAST") +infix fun Map>.from(session: FlowSession): R = this[session]!!.unwrap { it as R } + +/** + * Creates a [Pair([session], [Class])] from this [Class]. + */ +infix fun > T.from(session: FlowSession): Pair = session to this + +/** + * Creates a [Pair([session], [Class])] from this [KClass]. + */ +infix fun KClass.from(session: FlowSession): Pair> = session to this.javaObjectType + +/** + * Suspends until a message has been received for each session in the specified [sessions]. + * + * Consider [receiveAll(receiveType: Class, sessions: List): List>] when the same type is expected from all sessions. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them. + */ +@Suspendable +fun FlowLogic<*>.receiveAll(session: Pair>, vararg sessions: Pair>): Map> { + val allSessions = arrayOf(session, *sessions) + allSessions.enforceNoDuplicates() + return receiveAll(mapOf(*allSessions)) +} + +/** + * Suspends until a message has been received for each session in the specified [sessions]. + * + * Consider [sessions: Map>): Map>] when sessions are expected to receive different types. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. + */ +@Suspendable +fun FlowLogic<*>.receiveAll(receiveType: Class, session: FlowSession, vararg sessions: FlowSession): List> = receiveAll(receiveType, listOf(session, *sessions)) + +/** + * Suspends until a message has been received for each session in the specified [sessions]. + * + * Consider [sessions: Map>): Map>] when sessions are expected to receive different types. + * + * Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly + * verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly + * corrupted data in order to exploit your code. + * + * @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions]. + */ +@Suspendable +inline fun FlowLogic<*>.receiveAll(session: FlowSession, vararg sessions: FlowSession): List> = receiveAll(R::class.javaObjectType, listOf(session, *sessions)) + +private fun Array>>.enforceNoDuplicates() { + require(this.size == this.toSet().size) { "A flow session can only appear once as argument." } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt new file mode 100644 index 0000000000..cfadcf4bf5 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/flows/ReceiveAllFlowTests.kt @@ -0,0 +1,87 @@ +package net.corda.core.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.identity.Party +import net.corda.core.utilities.UntrustworthyData +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.testing.chooseIdentity +import net.corda.testing.node.network +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class ReceiveMultipleFlowTests { + @Test + fun `receive all messages in parallel using map style`() { + network(3) { nodes, _ -> + val doubleValue = 5.0 + nodes[1].registerAnswer(AlgorithmDefinition::class, doubleValue) + val stringValue = "Thriller" + nodes[2].registerAnswer(AlgorithmDefinition::class, stringValue) + + val flow = nodes[0].services.startFlow(ParallelAlgorithmMap(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity())) + runNetwork() + + val result = flow.resultFuture.getOrThrow() + + assertThat(result).isEqualTo(doubleValue * stringValue.length) + } + } + + @Test + fun `receive all messages in parallel using list style`() { + network(3) { nodes, _ -> + val value1 = 5.0 + nodes[1].registerAnswer(ParallelAlgorithmList::class, value1) + val value2 = 6.0 + nodes[2].registerAnswer(ParallelAlgorithmList::class, value2) + + val flow = nodes[0].services.startFlow(ParallelAlgorithmList(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity())) + runNetwork() + val data = flow.resultFuture.getOrThrow() + + assertThat(data[0]).isEqualTo(value1) + assertThat(data[1]).isEqualTo(value2) + assertThat(data.fold(1.0) { a, b -> a * b }).isEqualTo(value1 * value2) + } + } + + class ParallelAlgorithmMap(doubleMember: Party, stringMember: Party) : AlgorithmDefinition(doubleMember, stringMember) { + @Suspendable + override fun askMembersForData(doubleMember: Party, stringMember: Party): Data { + val doubleSession = initiateFlow(doubleMember) + val stringSession = initiateFlow(stringMember) + val rawData = receiveAll(Double::class from doubleSession, String::class from stringSession) + return Data(rawData from doubleSession, rawData from stringSession) + } + } + + @InitiatingFlow + class ParallelAlgorithmList(private val member1: Party, private val member2: Party) : FlowLogic>() { + @Suspendable + override fun call(): List { + val session1 = initiateFlow(member1) + val session2 = initiateFlow(member2) + val data = receiveAll(session1, session2) + return computeAnswer(data) + } + + private fun computeAnswer(data: List>): List { + return data.map { element -> element.unwrap { it } } + } + } + + @InitiatingFlow + abstract class AlgorithmDefinition(private val doubleMember: Party, private val stringMember: Party) : FlowLogic() { + protected data class Data(val double: Double, val string: String) + + @Suspendable + protected abstract fun askMembersForData(doubleMember: Party, stringMember: Party): Data + + @Suspendable + override fun call(): Double { + val (double, string) = askMembersForData(doubleMember, stringMember) + return double * string.length + } + } +} \ No newline at end of file diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index f07f2e964e..0955876373 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -6,6 +6,7 @@ from the previous milestone release. UNRELEASED ---------- +* ``FlowLogic`` now exposes a series of function called ``receiveAll(...)`` allowing to join ``receive(...)`` instructions. * ``Cordform`` and node identity generation * Cordform may not specify a value for ``NetworkMap``, when that happens, during the task execution the following happens: diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt index 748cce9bd8..cd56d786ad 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt @@ -1,5 +1,6 @@ package net.corda.node.services.statemachine +import co.paralleluniverse.fibers.Suspendable import net.corda.core.crypto.SecureHash interface FlowIORequest { @@ -8,7 +9,9 @@ interface FlowIORequest { val stackTraceInCaseOfProblems: StackSnapshot } -interface WaitingRequest : FlowIORequest +interface WaitingRequest : FlowIORequest { + fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean +} interface SessionedFlowIORequest : FlowIORequest { val session: FlowSessionInternal @@ -21,6 +24,8 @@ interface SendRequest : SessionedFlowIORequest { interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest { val receiveType: Class val userReceiveType: Class<*>? + + override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session } data class SendAndReceive(override val session: FlowSessionInternal, @@ -38,6 +43,63 @@ data class ReceiveOnly(override val session: FlowSessionInte override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } +class ReceiveAll(val requests: List>) : WaitingRequest { + @Transient + override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() + + private fun isComplete(received: LinkedHashMap): Boolean { + return received.keys == requests.map { it.session }.toSet() + } + private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) } + + private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean { + return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd } + } + + @Suspendable + fun suspendAndExpectReceive(suspend: Suspend): Map { + val receivedMessages = LinkedHashMap() + + poll(receivedMessages) + return if (isComplete(receivedMessages)) { + receivedMessages + } else { + suspend(this) + poll(receivedMessages) + if (isComplete(receivedMessages)) { + receivedMessages + } else { + throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" }) + } + } + } + + interface Suspend { + @Suspendable + operator fun invoke(request: FlowIORequest) + } + + @Suspendable + private fun poll(receivedMessages: LinkedHashMap) { + return requests.filter { it.session !in receivedMessages.keys }.forEach { request -> + poll(request)?.let { + receivedMessages[request.session] = RequestMessage(request, it) + } + } + } + + @Suspendable + private fun poll(request: ReceiveRequest): ReceivedSessionMessage<*>? { + return request.session.receivedMessages.poll() + } + + override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant() + + private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session } + + data class RequestMessage(val request: ReceiveRequest, val message: ReceivedSessionMessage<*>) +} + data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() @@ -46,6 +108,8 @@ data class SendOnly(override val session: FlowSessionInternal, override val mess data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() + + override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd } class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt index 4f2d1ba5fc..dc5b39c6f5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt @@ -2,6 +2,7 @@ package net.corda.node.services.statemachine import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession import net.corda.core.identity.Party import net.corda.node.services.statemachine.FlowSessionState.Initiated import net.corda.node.services.statemachine.FlowSessionState.Initiating @@ -15,6 +16,7 @@ import java.util.concurrent.ConcurrentLinkedQueue // TODO rename this class FlowSessionInternal( val flow: FlowLogic<*>, + val flowSession : FlowSession, val ourSessionId: Long, val initiatingParty: Party?, var state: FlowSessionState, diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 59f40a8533..0de001283e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -12,9 +12,13 @@ import net.corda.core.crypto.random63BitValue import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.* +import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.abbreviate import net.corda.core.internal.concurrent.OpenFuture import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.isRegularFile +import net.corda.core.internal.staticField +import net.corda.core.internal.uncheckedCast import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.* import net.corda.node.services.api.FlowAppAuditEvent @@ -171,8 +175,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, "@${InitiatingFlow::class.java.simpleName} sub-flow." ) } - createNewSession(otherParty, sessionFlow) val flowSession = FlowSessionImpl(otherParty) + createNewSession(otherParty, flowSession, sessionFlow) flowSession.stateMachine = this flowSession.sessionFlow = sessionFlow return flowSession @@ -299,6 +303,22 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id) } + @Suspendable + override fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> { + val requests = ArrayList>() + for ((session, receiveType) in sessions) { + val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow) + requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType)) + } + val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend) + val result = LinkedHashMap>() + for ((sessionInternal, requestAndMessage) in receivedMessages) { + val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request) + result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class) + } + return result + } + /** * This method will suspend the state machine and wait for incoming session init response from other party. */ @@ -362,10 +382,11 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, private fun createNewSession( otherParty: Party, + flowSession: FlowSession, sessionFlow: FlowLogic<*> ) { logger.trace { "Creating a new session with $otherParty" } - val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty)) + val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty)) openSessions[Pair(sessionFlow, otherParty)] = session } @@ -402,6 +423,13 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest) } + private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend { + @Suspendable + override fun invoke(request: FlowIORequest) { + suspend(request) + } + } + @Suspendable private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { val polledMessage = session.receivedMessages.poll() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 28e2c58958..c8189063a6 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -15,7 +15,11 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue import net.corda.core.flows.* import net.corda.core.identity.Party -import net.corda.core.internal.* +import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.castIfPossible +import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.DataFeed import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY @@ -342,8 +346,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, // commit but a counterparty flow has ended with an error (in which case our flow also has to end) private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { val waitingForResponse = session.fiber.waitingForResponse - return (waitingForResponse as? ReceiveRequest<*>)?.session === session || - waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd + return waitingForResponse?.shouldResume(message, session) ?: false } private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { @@ -362,6 +365,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } val session = FlowSessionInternal( flow, + flowSession, random63BitValue(), sender, FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt index b1f02788c6..52989475b3 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -51,6 +51,7 @@ import net.corda.testing.resetTestSerialization import net.corda.testing.testNodeConfiguration import org.apache.activemq.artemis.utils.ReusableLatch import org.slf4j.Logger +import java.io.Closeable import java.math.BigInteger import java.nio.file.Path import java.security.KeyPair @@ -81,12 +82,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory, - private val initialiseSerialization: Boolean = true) { + private val initialiseSerialization: Boolean = true) : Closeable { companion object { // TODO In future PR we're removing the concept of network map node so the details of this mock are not important. val MOCK_NET_MAP = Party(CordaX500Name(organisation = "Mock Network Map", locality = "Madrid", country = "ES"), DUMMY_KEY_1.public) } - var nextNodeId = 0 private set private val filesystem = Jimfs.newFileSystem(unix()) @@ -445,4 +445,17 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, fun waitQuiescent() { busyLatch.await() } + + override fun close() { + stopNodes() + } } + +fun network(nodesCount: Int, action: MockNetwork.(nodes: List>, notary: StartedNode) -> Unit) { + MockNetwork().use { + it.runNetwork() + val notary = it.createNotaryNode() + val nodes = (1..nodesCount).map { _ -> it.createPartyNode() } + action(it, nodes, notary) + } +} \ No newline at end of file