[CORDA-683] Enable receiveAll() from Flows.

This commit is contained in:
Michele Sollecito 2017-10-09 13:46:37 +01:00 committed by GitHub
parent 20a30b30da
commit 29a101c378
11 changed files with 377 additions and 15 deletions

1
.gitignore vendored
View File

@ -34,6 +34,7 @@ lib/quasar.jar
.idea/shelf .idea/shelf
.idea/dataSources .idea/dataSources
.idea/markdown-navigator .idea/markdown-navigator
.idea/runConfigurations
/gradle-plugins/.idea/ /gradle-plugins/.idea/
# Include the -parameters compiler option by default in IntelliJ required for serialization. # Include the -parameters compiler option by default in IntelliJ required for serialization.

View File

@ -6,6 +6,7 @@ import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.abbreviate import net.corda.core.internal.abbreviate
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
@ -177,6 +178,38 @@ abstract class FlowLogic<out T> {
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions) return stateMachine.receive(receiveType, otherParty, flowUsedForSessions)
} }
/** Suspends until a message has been received for each session in the specified [sessions].
*
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
*/
@Suspendable
open fun receiveAll(sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
return stateMachine.receiveAll(sessions, this)
}
/**
* Suspends until a message has been received for each session in the specified [sessions].
*
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
*/
@Suspendable
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>> {
enforceNoDuplicates(sessions)
return castMapValuesToKnownType(receiveAll(associateSessionsToReceiveType(receiveType, sessions)))
}
/** /**
* Queues the given [payload] for sending to the [otherParty] and continues without suspending. * Queues the given [payload] for sending to the [otherParty] and continues without suspending.
* *
@ -231,7 +264,6 @@ abstract class FlowLogic<out T> {
stateMachine.checkFlowPermission(permissionName, extraAuditData) stateMachine.checkFlowPermission(permissionName, extraAuditData)
} }
/** /**
* Flows can call this method to record application level flow audit events * Flows can call this method to record application level flow audit events
* @param eventType is a string representing the type of event. Each flow is given a distinct namespace for these names. * @param eventType is a string representing the type of event. Each flow is given a distinct namespace for these names.
@ -334,6 +366,18 @@ abstract class FlowLogic<out T> {
ours.setChildProgressTracker(ours.currentStep, theirs) ours.setChildProgressTracker(ours.currentStep, theirs)
} }
} }
private fun enforceNoDuplicates(sessions: List<FlowSession>) {
require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." }
}
private fun <R> associateSessionsToReceiveType(receiveType: Class<R>, sessions: List<FlowSession>): Map<FlowSession, Class<R>> {
return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType })
}
private fun <R> castMapValuesToKnownType(map: Map<FlowSession, UntrustworthyData<Any>>): List<UntrustworthyData<R>> {
return map.values.map { uncheckedCast<Any, UntrustworthyData<R>>(it) }
}
} }
/** /**

View File

@ -30,20 +30,20 @@ interface FlowStateMachine<R> {
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T> fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
@Suspendable @Suspendable
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>): Unit fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
@Suspendable @Suspendable
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>): Unit fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>)
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>): Unit fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>)
@Suspendable @Suspendable
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot? fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
@Suspendable @Suspendable
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>): Unit fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>)
val serviceHub: ServiceHub val serviceHub: ServiceHub
val logger: Logger val logger: Logger
@ -51,4 +51,7 @@ interface FlowStateMachine<R> {
val resultFuture: CordaFuture<R> val resultFuture: CordaFuture<R>
val flowInitiator: FlowInitiator val flowInitiator: FlowInitiator
val ourIdentityAndCert: PartyAndCertificate val ourIdentityAndCert: PartyAndCertificate
@Suspendable
fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>>
} }

View File

@ -0,0 +1,115 @@
package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.unwrap
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.internal.StartedNode
import kotlin.reflect.KClass
/**
* Allows to simplify writing flows that simply rend a message back to an initiating flow.
*/
class Answer<out R : Any>(session: FlowSession, override val answer: R, closure: (result: R) -> Unit = {}) : SimpleAnswer<R>(session, closure)
/**
* Allows to simplify writing flows that simply rend a message back to an initiating flow.
*/
abstract class SimpleAnswer<out R : Any>(private val session: FlowSession, private val closure: (result: R) -> Unit = {}) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val tmp = answer
closure(tmp)
session.send(tmp)
}
protected abstract val answer: R
}
/**
* A flow that does not do anything when triggered.
*/
class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic<Unit>() {
@Suspendable
override fun call() = closure()
}
/**
* Allows to register a flow of type [R] against an initiating flow of type [I].
*/
inline fun <I : FlowLogic<*>, reified R : FlowLogic<*>> StartedNode<*>.registerInitiatedFlow(initiatingFlowType: KClass<I>, crossinline construct: (session: FlowSession) -> R) {
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true)
}
/**
* Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R].
*/
inline fun <I : FlowLogic<*>, reified R : Any> StartedNode<*>.registerAnswer(initiatingFlowType: KClass<I>, value: R) {
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true)
}
/**
* Extracts data from a [Map[FlowSession, UntrustworthyData<Any>]] without performing checks and casting to [R].
*/
@Suppress("UNCHECKED_CAST")
infix fun <R : Any> Map<FlowSession, UntrustworthyData<Any>>.from(session: FlowSession): R = this[session]!!.unwrap { it as R }
/**
* Creates a [Pair([session], [Class])] from this [Class].
*/
infix fun <T : Class<out Any>> T.from(session: FlowSession): Pair<FlowSession, T> = session to this
/**
* Creates a [Pair([session], [Class])] from this [KClass].
*/
infix fun <T : Any> KClass<T>.from(session: FlowSession): Pair<FlowSession, Class<T>> = session to this.javaObjectType
/**
* Suspends until a message has been received for each session in the specified [sessions].
*
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
*/
@Suspendable
fun FlowLogic<*>.receiveAll(session: Pair<FlowSession, Class<out Any>>, vararg sessions: Pair<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
val allSessions = arrayOf(session, *sessions)
allSessions.enforceNoDuplicates()
return receiveAll(mapOf(*allSessions))
}
/**
* Suspends until a message has been received for each session in the specified [sessions].
*
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
*/
@Suspendable
fun <R : Any> FlowLogic<*>.receiveAll(receiveType: Class<R>, session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(receiveType, listOf(session, *sessions))
/**
* Suspends until a message has been received for each session in the specified [sessions].
*
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
*
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
* corrupted data in order to exploit your code.
*
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
*/
@Suspendable
inline fun <reified R : Any> FlowLogic<*>.receiveAll(session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(R::class.javaObjectType, listOf(session, *sessions))
private fun Array<out Pair<FlowSession, Class<out Any>>>.enforceNoDuplicates() {
require(this.size == this.toSet().size) { "A flow session can only appear once as argument." }
}

View File

@ -0,0 +1,87 @@
package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.identity.Party
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.testing.chooseIdentity
import net.corda.testing.node.network
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class ReceiveMultipleFlowTests {
@Test
fun `receive all messages in parallel using map style`() {
network(3) { nodes, _ ->
val doubleValue = 5.0
nodes[1].registerAnswer(AlgorithmDefinition::class, doubleValue)
val stringValue = "Thriller"
nodes[2].registerAnswer(AlgorithmDefinition::class, stringValue)
val flow = nodes[0].services.startFlow(ParallelAlgorithmMap(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity()))
runNetwork()
val result = flow.resultFuture.getOrThrow()
assertThat(result).isEqualTo(doubleValue * stringValue.length)
}
}
@Test
fun `receive all messages in parallel using list style`() {
network(3) { nodes, _ ->
val value1 = 5.0
nodes[1].registerAnswer(ParallelAlgorithmList::class, value1)
val value2 = 6.0
nodes[2].registerAnswer(ParallelAlgorithmList::class, value2)
val flow = nodes[0].services.startFlow(ParallelAlgorithmList(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity()))
runNetwork()
val data = flow.resultFuture.getOrThrow()
assertThat(data[0]).isEqualTo(value1)
assertThat(data[1]).isEqualTo(value2)
assertThat(data.fold(1.0) { a, b -> a * b }).isEqualTo(value1 * value2)
}
}
class ParallelAlgorithmMap(doubleMember: Party, stringMember: Party) : AlgorithmDefinition(doubleMember, stringMember) {
@Suspendable
override fun askMembersForData(doubleMember: Party, stringMember: Party): Data {
val doubleSession = initiateFlow(doubleMember)
val stringSession = initiateFlow(stringMember)
val rawData = receiveAll(Double::class from doubleSession, String::class from stringSession)
return Data(rawData from doubleSession, rawData from stringSession)
}
}
@InitiatingFlow
class ParallelAlgorithmList(private val member1: Party, private val member2: Party) : FlowLogic<List<Double>>() {
@Suspendable
override fun call(): List<Double> {
val session1 = initiateFlow(member1)
val session2 = initiateFlow(member2)
val data = receiveAll<Double>(session1, session2)
return computeAnswer(data)
}
private fun computeAnswer(data: List<UntrustworthyData<Double>>): List<Double> {
return data.map { element -> element.unwrap { it } }
}
}
@InitiatingFlow
abstract class AlgorithmDefinition(private val doubleMember: Party, private val stringMember: Party) : FlowLogic<Double>() {
protected data class Data(val double: Double, val string: String)
@Suspendable
protected abstract fun askMembersForData(doubleMember: Party, stringMember: Party): Data
@Suspendable
override fun call(): Double {
val (double, string) = askMembersForData(doubleMember, stringMember)
return double * string.length
}
}
}

View File

@ -6,6 +6,7 @@ from the previous milestone release.
UNRELEASED UNRELEASED
---------- ----------
* ``FlowLogic`` now exposes a series of function called ``receiveAll(...)`` allowing to join ``receive(...)`` instructions.
* ``Cordform`` and node identity generation * ``Cordform`` and node identity generation
* Cordform may not specify a value for ``NetworkMap``, when that happens, during the task execution the following happens: * Cordform may not specify a value for ``NetworkMap``, when that happens, during the task execution the following happens:

View File

@ -1,5 +1,6 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
interface FlowIORequest { interface FlowIORequest {
@ -8,7 +9,9 @@ interface FlowIORequest {
val stackTraceInCaseOfProblems: StackSnapshot val stackTraceInCaseOfProblems: StackSnapshot
} }
interface WaitingRequest : FlowIORequest interface WaitingRequest : FlowIORequest {
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
}
interface SessionedFlowIORequest : FlowIORequest { interface SessionedFlowIORequest : FlowIORequest {
val session: FlowSessionInternal val session: FlowSessionInternal
@ -21,6 +24,8 @@ interface SendRequest : SessionedFlowIORequest {
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest { interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T> val receiveType: Class<T>
val userReceiveType: Class<*>? val userReceiveType: Class<*>?
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
} }
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal, data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
@ -38,6 +43,63 @@ data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInte
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
return received.keys == requests.map { it.session }.toSet()
}
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
}
@Suspendable
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
poll(receivedMessages)
return if (isComplete(receivedMessages)) {
receivedMessages
} else {
suspend(this)
poll(receivedMessages)
if (isComplete(receivedMessages)) {
receivedMessages
} else {
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
}
}
}
interface Suspend {
@Suspendable
operator fun invoke(request: FlowIORequest)
}
@Suspendable
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
poll(request)?.let {
receivedMessages[request.session] = RequestMessage(request, it)
}
}
}
@Suspendable
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
return request.session.receivedMessages.poll()
}
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
}
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
@ -46,6 +108,8 @@ data class SendOnly(override val session: FlowSessionInternal, override val mess
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest { data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
} }
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.node.services.statemachine.FlowSessionState.Initiated import net.corda.node.services.statemachine.FlowSessionState.Initiated
import net.corda.node.services.statemachine.FlowSessionState.Initiating import net.corda.node.services.statemachine.FlowSessionState.Initiating
@ -15,6 +16,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
// TODO rename this // TODO rename this
class FlowSessionInternal( class FlowSessionInternal(
val flow: FlowLogic<*>, val flow: FlowLogic<*>,
val flowSession : FlowSession,
val ourSessionId: Long, val ourSessionId: Long,
val initiatingParty: Party?, val initiatingParty: Party?,
var state: FlowSessionState, var state: FlowSessionState,

View File

@ -12,9 +12,13 @@ import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.* import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.abbreviate
import net.corda.core.internal.concurrent.OpenFuture import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.isRegularFile
import net.corda.core.internal.staticField
import net.corda.core.internal.uncheckedCast
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.node.services.api.FlowAppAuditEvent import net.corda.node.services.api.FlowAppAuditEvent
@ -171,8 +175,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
"@${InitiatingFlow::class.java.simpleName} sub-flow." "@${InitiatingFlow::class.java.simpleName} sub-flow."
) )
} }
createNewSession(otherParty, sessionFlow)
val flowSession = FlowSessionImpl(otherParty) val flowSession = FlowSessionImpl(otherParty)
createNewSession(otherParty, flowSession, sessionFlow)
flowSession.stateMachine = this flowSession.stateMachine = this
flowSession.sessionFlow = sessionFlow flowSession.sessionFlow = sessionFlow
return flowSession return flowSession
@ -299,6 +303,22 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id) FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
} }
@Suspendable
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
val requests = ArrayList<ReceiveOnly<SessionData>>()
for ((session, receiveType) in sessions) {
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
}
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
for ((sessionInternal, requestAndMessage) in receivedMessages) {
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
}
return result
}
/** /**
* This method will suspend the state machine and wait for incoming session init response from other party. * This method will suspend the state machine and wait for incoming session init response from other party.
*/ */
@ -362,10 +382,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
private fun createNewSession( private fun createNewSession(
otherParty: Party, otherParty: Party,
flowSession: FlowSession,
sessionFlow: FlowLogic<*> sessionFlow: FlowLogic<*>
) { ) {
logger.trace { "Creating a new session with $otherParty" } logger.trace { "Creating a new session with $otherParty" }
val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty)) val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session openSessions[Pair(sessionFlow, otherParty)] = session
} }
@ -402,6 +423,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest) return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
} }
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
@Suspendable
override fun invoke(request: FlowIORequest) {
suspend(request)
}
}
@Suspendable @Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
val polledMessage = session.receivedMessages.poll() val polledMessage = session.receivedMessages.poll()

View File

@ -15,7 +15,11 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.* import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
@ -342,8 +346,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// commit but a counterparty flow has ended with an error (in which case our flow also has to end) // commit but a counterparty flow has ended with an error (in which case our flow also has to end)
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean { private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
val waitingForResponse = session.fiber.waitingForResponse val waitingForResponse = session.fiber.waitingForResponse
return (waitingForResponse as? ReceiveRequest<*>)?.session === session || return waitingForResponse?.shouldResume(message, session) ?: false
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
} }
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
@ -362,6 +365,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
val session = FlowSessionInternal( val session = FlowSessionInternal(
flow, flow,
flowSession,
random63BitValue(), random63BitValue(),
sender, sender,
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))

View File

@ -51,6 +51,7 @@ import net.corda.testing.resetTestSerialization
import net.corda.testing.testNodeConfiguration import net.corda.testing.testNodeConfiguration
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.slf4j.Logger import org.slf4j.Logger
import java.io.Closeable
import java.math.BigInteger import java.math.BigInteger
import java.nio.file.Path import java.nio.file.Path
import java.security.KeyPair import java.security.KeyPair
@ -81,12 +82,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy =
InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(),
private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory, private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory,
private val initialiseSerialization: Boolean = true) { private val initialiseSerialization: Boolean = true) : Closeable {
companion object { companion object {
// TODO In future PR we're removing the concept of network map node so the details of this mock are not important. // TODO In future PR we're removing the concept of network map node so the details of this mock are not important.
val MOCK_NET_MAP = Party(CordaX500Name(organisation = "Mock Network Map", locality = "Madrid", country = "ES"), DUMMY_KEY_1.public) val MOCK_NET_MAP = Party(CordaX500Name(organisation = "Mock Network Map", locality = "Madrid", country = "ES"), DUMMY_KEY_1.public)
} }
var nextNodeId = 0 var nextNodeId = 0
private set private set
private val filesystem = Jimfs.newFileSystem(unix()) private val filesystem = Jimfs.newFileSystem(unix())
@ -445,4 +445,17 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
fun waitQuiescent() { fun waitQuiescent() {
busyLatch.await() busyLatch.await()
} }
override fun close() {
stopNodes()
}
}
fun network(nodesCount: Int, action: MockNetwork.(nodes: List<StartedNode<MockNetwork.MockNode>>, notary: StartedNode<MockNetwork.MockNode>) -> Unit) {
MockNetwork().use {
it.runNetwork()
val notary = it.createNotaryNode()
val nodes = (1..nodesCount).map { _ -> it.createPartyNode() }
action(it, nodes, notary)
}
} }