mirror of
https://github.com/corda/corda.git
synced 2024-12-24 07:06:44 +00:00
[CORDA-683] Enable receiveAll()
from Flows.
This commit is contained in:
parent
20a30b30da
commit
29a101c378
1
.gitignore
vendored
1
.gitignore
vendored
@ -34,6 +34,7 @@ lib/quasar.jar
|
||||
.idea/shelf
|
||||
.idea/dataSources
|
||||
.idea/markdown-navigator
|
||||
.idea/runConfigurations
|
||||
/gradle-plugins/.idea/
|
||||
|
||||
# Include the -parameters compiler option by default in IntelliJ required for serialization.
|
||||
|
@ -6,6 +6,7 @@ import net.corda.core.identity.Party
|
||||
import net.corda.core.identity.PartyAndCertificate
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.abbreviate
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.node.NodeInfo
|
||||
import net.corda.core.node.ServiceHub
|
||||
@ -177,6 +178,38 @@ abstract class FlowLogic<out T> {
|
||||
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions)
|
||||
}
|
||||
|
||||
/** Suspends until a message has been received for each session in the specified [sessions].
|
||||
*
|
||||
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
|
||||
*
|
||||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||
* corrupted data in order to exploit your code.
|
||||
*
|
||||
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
|
||||
*/
|
||||
@Suspendable
|
||||
open fun receiveAll(sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
return stateMachine.receiveAll(sessions, this)
|
||||
}
|
||||
|
||||
/**
|
||||
* Suspends until a message has been received for each session in the specified [sessions].
|
||||
*
|
||||
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
|
||||
*
|
||||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||
* corrupted data in order to exploit your code.
|
||||
*
|
||||
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||
*/
|
||||
@Suspendable
|
||||
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>> {
|
||||
enforceNoDuplicates(sessions)
|
||||
return castMapValuesToKnownType(receiveAll(associateSessionsToReceiveType(receiveType, sessions)))
|
||||
}
|
||||
|
||||
/**
|
||||
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
|
||||
*
|
||||
@ -231,7 +264,6 @@ abstract class FlowLogic<out T> {
|
||||
stateMachine.checkFlowPermission(permissionName, extraAuditData)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Flows can call this method to record application level flow audit events
|
||||
* @param eventType is a string representing the type of event. Each flow is given a distinct namespace for these names.
|
||||
@ -334,6 +366,18 @@ abstract class FlowLogic<out T> {
|
||||
ours.setChildProgressTracker(ours.currentStep, theirs)
|
||||
}
|
||||
}
|
||||
|
||||
private fun enforceNoDuplicates(sessions: List<FlowSession>) {
|
||||
require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." }
|
||||
}
|
||||
|
||||
private fun <R> associateSessionsToReceiveType(receiveType: Class<R>, sessions: List<FlowSession>): Map<FlowSession, Class<R>> {
|
||||
return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType })
|
||||
}
|
||||
|
||||
private fun <R> castMapValuesToKnownType(map: Map<FlowSession, UntrustworthyData<Any>>): List<UntrustworthyData<R>> {
|
||||
return map.values.map { uncheckedCast<Any, UntrustworthyData<R>>(it) }
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -30,20 +30,20 @@ interface FlowStateMachine<R> {
|
||||
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||
|
||||
@Suspendable
|
||||
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>): Unit
|
||||
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
|
||||
|
||||
@Suspendable
|
||||
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction
|
||||
|
||||
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>): Unit
|
||||
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>)
|
||||
|
||||
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>): Unit
|
||||
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>)
|
||||
|
||||
@Suspendable
|
||||
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
|
||||
|
||||
@Suspendable
|
||||
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>): Unit
|
||||
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>)
|
||||
|
||||
val serviceHub: ServiceHub
|
||||
val logger: Logger
|
||||
@ -51,4 +51,7 @@ interface FlowStateMachine<R> {
|
||||
val resultFuture: CordaFuture<R>
|
||||
val flowInitiator: FlowInitiator
|
||||
val ourIdentityAndCert: PartyAndCertificate
|
||||
|
||||
@Suspendable
|
||||
fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>>
|
||||
}
|
||||
|
115
core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt
Normal file
115
core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt
Normal file
@ -0,0 +1,115 @@
|
||||
package net.corda.core.flows
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.node.internal.InitiatedFlowFactory
|
||||
import net.corda.node.internal.StartedNode
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* Allows to simplify writing flows that simply rend a message back to an initiating flow.
|
||||
*/
|
||||
class Answer<out R : Any>(session: FlowSession, override val answer: R, closure: (result: R) -> Unit = {}) : SimpleAnswer<R>(session, closure)
|
||||
|
||||
/**
|
||||
* Allows to simplify writing flows that simply rend a message back to an initiating flow.
|
||||
*/
|
||||
abstract class SimpleAnswer<out R : Any>(private val session: FlowSession, private val closure: (result: R) -> Unit = {}) : FlowLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val tmp = answer
|
||||
closure(tmp)
|
||||
session.send(tmp)
|
||||
}
|
||||
|
||||
protected abstract val answer: R
|
||||
}
|
||||
|
||||
/**
|
||||
* A flow that does not do anything when triggered.
|
||||
*/
|
||||
class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() = closure()
|
||||
}
|
||||
|
||||
/**
|
||||
* Allows to register a flow of type [R] against an initiating flow of type [I].
|
||||
*/
|
||||
inline fun <I : FlowLogic<*>, reified R : FlowLogic<*>> StartedNode<*>.registerInitiatedFlow(initiatingFlowType: KClass<I>, crossinline construct: (session: FlowSession) -> R) {
|
||||
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true)
|
||||
}
|
||||
|
||||
/**
|
||||
* Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R].
|
||||
*/
|
||||
inline fun <I : FlowLogic<*>, reified R : Any> StartedNode<*>.registerAnswer(initiatingFlowType: KClass<I>, value: R) {
|
||||
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true)
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts data from a [Map[FlowSession, UntrustworthyData<Any>]] without performing checks and casting to [R].
|
||||
*/
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
infix fun <R : Any> Map<FlowSession, UntrustworthyData<Any>>.from(session: FlowSession): R = this[session]!!.unwrap { it as R }
|
||||
|
||||
/**
|
||||
* Creates a [Pair([session], [Class])] from this [Class].
|
||||
*/
|
||||
infix fun <T : Class<out Any>> T.from(session: FlowSession): Pair<FlowSession, T> = session to this
|
||||
|
||||
/**
|
||||
* Creates a [Pair([session], [Class])] from this [KClass].
|
||||
*/
|
||||
infix fun <T : Any> KClass<T>.from(session: FlowSession): Pair<FlowSession, Class<T>> = session to this.javaObjectType
|
||||
|
||||
/**
|
||||
* Suspends until a message has been received for each session in the specified [sessions].
|
||||
*
|
||||
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
|
||||
*
|
||||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||
* corrupted data in order to exploit your code.
|
||||
*
|
||||
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
|
||||
*/
|
||||
@Suspendable
|
||||
fun FlowLogic<*>.receiveAll(session: Pair<FlowSession, Class<out Any>>, vararg sessions: Pair<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
val allSessions = arrayOf(session, *sessions)
|
||||
allSessions.enforceNoDuplicates()
|
||||
return receiveAll(mapOf(*allSessions))
|
||||
}
|
||||
|
||||
/**
|
||||
* Suspends until a message has been received for each session in the specified [sessions].
|
||||
*
|
||||
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
|
||||
*
|
||||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||
* corrupted data in order to exploit your code.
|
||||
*
|
||||
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||
*/
|
||||
@Suspendable
|
||||
fun <R : Any> FlowLogic<*>.receiveAll(receiveType: Class<R>, session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(receiveType, listOf(session, *sessions))
|
||||
|
||||
/**
|
||||
* Suspends until a message has been received for each session in the specified [sessions].
|
||||
*
|
||||
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
|
||||
*
|
||||
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||
* corrupted data in order to exploit your code.
|
||||
*
|
||||
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||
*/
|
||||
@Suspendable
|
||||
inline fun <reified R : Any> FlowLogic<*>.receiveAll(session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(R::class.javaObjectType, listOf(session, *sessions))
|
||||
|
||||
private fun Array<out Pair<FlowSession, Class<out Any>>>.enforceNoDuplicates() {
|
||||
require(this.size == this.toSet().size) { "A flow session can only appear once as argument." }
|
||||
}
|
@ -0,0 +1,87 @@
|
||||
package net.corda.core.flows
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.testing.chooseIdentity
|
||||
import net.corda.testing.node.network
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class ReceiveMultipleFlowTests {
|
||||
@Test
|
||||
fun `receive all messages in parallel using map style`() {
|
||||
network(3) { nodes, _ ->
|
||||
val doubleValue = 5.0
|
||||
nodes[1].registerAnswer(AlgorithmDefinition::class, doubleValue)
|
||||
val stringValue = "Thriller"
|
||||
nodes[2].registerAnswer(AlgorithmDefinition::class, stringValue)
|
||||
|
||||
val flow = nodes[0].services.startFlow(ParallelAlgorithmMap(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity()))
|
||||
runNetwork()
|
||||
|
||||
val result = flow.resultFuture.getOrThrow()
|
||||
|
||||
assertThat(result).isEqualTo(doubleValue * stringValue.length)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `receive all messages in parallel using list style`() {
|
||||
network(3) { nodes, _ ->
|
||||
val value1 = 5.0
|
||||
nodes[1].registerAnswer(ParallelAlgorithmList::class, value1)
|
||||
val value2 = 6.0
|
||||
nodes[2].registerAnswer(ParallelAlgorithmList::class, value2)
|
||||
|
||||
val flow = nodes[0].services.startFlow(ParallelAlgorithmList(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity()))
|
||||
runNetwork()
|
||||
val data = flow.resultFuture.getOrThrow()
|
||||
|
||||
assertThat(data[0]).isEqualTo(value1)
|
||||
assertThat(data[1]).isEqualTo(value2)
|
||||
assertThat(data.fold(1.0) { a, b -> a * b }).isEqualTo(value1 * value2)
|
||||
}
|
||||
}
|
||||
|
||||
class ParallelAlgorithmMap(doubleMember: Party, stringMember: Party) : AlgorithmDefinition(doubleMember, stringMember) {
|
||||
@Suspendable
|
||||
override fun askMembersForData(doubleMember: Party, stringMember: Party): Data {
|
||||
val doubleSession = initiateFlow(doubleMember)
|
||||
val stringSession = initiateFlow(stringMember)
|
||||
val rawData = receiveAll(Double::class from doubleSession, String::class from stringSession)
|
||||
return Data(rawData from doubleSession, rawData from stringSession)
|
||||
}
|
||||
}
|
||||
|
||||
@InitiatingFlow
|
||||
class ParallelAlgorithmList(private val member1: Party, private val member2: Party) : FlowLogic<List<Double>>() {
|
||||
@Suspendable
|
||||
override fun call(): List<Double> {
|
||||
val session1 = initiateFlow(member1)
|
||||
val session2 = initiateFlow(member2)
|
||||
val data = receiveAll<Double>(session1, session2)
|
||||
return computeAnswer(data)
|
||||
}
|
||||
|
||||
private fun computeAnswer(data: List<UntrustworthyData<Double>>): List<Double> {
|
||||
return data.map { element -> element.unwrap { it } }
|
||||
}
|
||||
}
|
||||
|
||||
@InitiatingFlow
|
||||
abstract class AlgorithmDefinition(private val doubleMember: Party, private val stringMember: Party) : FlowLogic<Double>() {
|
||||
protected data class Data(val double: Double, val string: String)
|
||||
|
||||
@Suspendable
|
||||
protected abstract fun askMembersForData(doubleMember: Party, stringMember: Party): Data
|
||||
|
||||
@Suspendable
|
||||
override fun call(): Double {
|
||||
val (double, string) = askMembersForData(doubleMember, stringMember)
|
||||
return double * string.length
|
||||
}
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ from the previous milestone release.
|
||||
|
||||
UNRELEASED
|
||||
----------
|
||||
* ``FlowLogic`` now exposes a series of function called ``receiveAll(...)`` allowing to join ``receive(...)`` instructions.
|
||||
|
||||
* ``Cordform`` and node identity generation
|
||||
* Cordform may not specify a value for ``NetworkMap``, when that happens, during the task execution the following happens:
|
||||
|
@ -1,5 +1,6 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.crypto.SecureHash
|
||||
|
||||
interface FlowIORequest {
|
||||
@ -8,7 +9,9 @@ interface FlowIORequest {
|
||||
val stackTraceInCaseOfProblems: StackSnapshot
|
||||
}
|
||||
|
||||
interface WaitingRequest : FlowIORequest
|
||||
interface WaitingRequest : FlowIORequest {
|
||||
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
|
||||
}
|
||||
|
||||
interface SessionedFlowIORequest : FlowIORequest {
|
||||
val session: FlowSessionInternal
|
||||
@ -21,6 +24,8 @@ interface SendRequest : SessionedFlowIORequest {
|
||||
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
|
||||
val receiveType: Class<T>
|
||||
val userReceiveType: Class<*>?
|
||||
|
||||
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
|
||||
}
|
||||
|
||||
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
|
||||
@ -38,6 +43,63 @@ data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInte
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
}
|
||||
|
||||
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
|
||||
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
|
||||
return received.keys == requests.map { it.session }.toSet()
|
||||
}
|
||||
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
|
||||
|
||||
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
|
||||
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
|
||||
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
|
||||
|
||||
poll(receivedMessages)
|
||||
return if (isComplete(receivedMessages)) {
|
||||
receivedMessages
|
||||
} else {
|
||||
suspend(this)
|
||||
poll(receivedMessages)
|
||||
if (isComplete(receivedMessages)) {
|
||||
receivedMessages
|
||||
} else {
|
||||
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface Suspend {
|
||||
@Suspendable
|
||||
operator fun invoke(request: FlowIORequest)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
|
||||
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
|
||||
poll(request)?.let {
|
||||
receivedMessages[request.session] = RequestMessage(request, it)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
|
||||
return request.session.receivedMessages.poll()
|
||||
}
|
||||
|
||||
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
|
||||
|
||||
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
|
||||
|
||||
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
|
||||
}
|
||||
|
||||
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
@ -46,6 +108,8 @@ data class SendOnly(override val session: FlowSessionInternal, override val mess
|
||||
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
|
||||
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
|
||||
}
|
||||
|
||||
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
||||
|
@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.node.services.statemachine.FlowSessionState.Initiated
|
||||
import net.corda.node.services.statemachine.FlowSessionState.Initiating
|
||||
@ -15,6 +16,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
|
||||
// TODO rename this
|
||||
class FlowSessionInternal(
|
||||
val flow: FlowLogic<*>,
|
||||
val flowSession : FlowSession,
|
||||
val ourSessionId: Long,
|
||||
val initiatingParty: Party?,
|
||||
var state: FlowSessionState,
|
||||
|
@ -12,9 +12,13 @@ import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.identity.PartyAndCertificate
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.abbreviate
|
||||
import net.corda.core.internal.concurrent.OpenFuture
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.internal.isRegularFile
|
||||
import net.corda.core.internal.staticField
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.*
|
||||
import net.corda.node.services.api.FlowAppAuditEvent
|
||||
@ -171,8 +175,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
"@${InitiatingFlow::class.java.simpleName} sub-flow."
|
||||
)
|
||||
}
|
||||
createNewSession(otherParty, sessionFlow)
|
||||
val flowSession = FlowSessionImpl(otherParty)
|
||||
createNewSession(otherParty, flowSession, sessionFlow)
|
||||
flowSession.stateMachine = this
|
||||
flowSession.sessionFlow = sessionFlow
|
||||
return flowSession
|
||||
@ -299,6 +303,22 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
val requests = ArrayList<ReceiveOnly<SessionData>>()
|
||||
for ((session, receiveType) in sessions) {
|
||||
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
|
||||
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
|
||||
}
|
||||
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
|
||||
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
|
||||
for ((sessionInternal, requestAndMessage) in receivedMessages) {
|
||||
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
|
||||
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will suspend the state machine and wait for incoming session init response from other party.
|
||||
*/
|
||||
@ -362,10 +382,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
|
||||
private fun createNewSession(
|
||||
otherParty: Party,
|
||||
flowSession: FlowSession,
|
||||
sessionFlow: FlowLogic<*>
|
||||
) {
|
||||
logger.trace { "Creating a new session with $otherParty" }
|
||||
val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
|
||||
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
|
||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||
}
|
||||
|
||||
@ -402,6 +423,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
|
||||
}
|
||||
|
||||
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
|
||||
@Suspendable
|
||||
override fun invoke(request: FlowIORequest) {
|
||||
suspend(request)
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
|
||||
val polledMessage = session.receivedMessages.poll()
|
||||
|
@ -15,7 +15,11 @@ import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.ThreadBox
|
||||
import net.corda.core.internal.bufferUntilSubscribed
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
|
||||
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
|
||||
@ -342,8 +346,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
|
||||
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
|
||||
val waitingForResponse = session.fiber.waitingForResponse
|
||||
return (waitingForResponse as? ReceiveRequest<*>)?.session === session ||
|
||||
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
|
||||
return waitingForResponse?.shouldResume(message, session) ?: false
|
||||
}
|
||||
|
||||
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
|
||||
@ -362,6 +365,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
val session = FlowSessionInternal(
|
||||
flow,
|
||||
flowSession,
|
||||
random63BitValue(),
|
||||
sender,
|
||||
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
|
||||
|
@ -51,6 +51,7 @@ import net.corda.testing.resetTestSerialization
|
||||
import net.corda.testing.testNodeConfiguration
|
||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||
import org.slf4j.Logger
|
||||
import java.io.Closeable
|
||||
import java.math.BigInteger
|
||||
import java.nio.file.Path
|
||||
import java.security.KeyPair
|
||||
@ -81,12 +82,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
||||
servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy =
|
||||
InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(),
|
||||
private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory,
|
||||
private val initialiseSerialization: Boolean = true) {
|
||||
private val initialiseSerialization: Boolean = true) : Closeable {
|
||||
companion object {
|
||||
// TODO In future PR we're removing the concept of network map node so the details of this mock are not important.
|
||||
val MOCK_NET_MAP = Party(CordaX500Name(organisation = "Mock Network Map", locality = "Madrid", country = "ES"), DUMMY_KEY_1.public)
|
||||
}
|
||||
|
||||
var nextNodeId = 0
|
||||
private set
|
||||
private val filesystem = Jimfs.newFileSystem(unix())
|
||||
@ -445,4 +445,17 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
||||
fun waitQuiescent() {
|
||||
busyLatch.await()
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
stopNodes()
|
||||
}
|
||||
}
|
||||
|
||||
fun network(nodesCount: Int, action: MockNetwork.(nodes: List<StartedNode<MockNetwork.MockNode>>, notary: StartedNode<MockNetwork.MockNode>) -> Unit) {
|
||||
MockNetwork().use {
|
||||
it.runNetwork()
|
||||
val notary = it.createNotaryNode()
|
||||
val nodes = (1..nodesCount).map { _ -> it.createPartyNode() }
|
||||
action(it, nodes, notary)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user