mirror of
https://github.com/corda/corda.git
synced 2024-12-24 07:06:44 +00:00
[CORDA-683] Enable receiveAll()
from Flows.
This commit is contained in:
parent
20a30b30da
commit
29a101c378
1
.gitignore
vendored
1
.gitignore
vendored
@ -34,6 +34,7 @@ lib/quasar.jar
|
|||||||
.idea/shelf
|
.idea/shelf
|
||||||
.idea/dataSources
|
.idea/dataSources
|
||||||
.idea/markdown-navigator
|
.idea/markdown-navigator
|
||||||
|
.idea/runConfigurations
|
||||||
/gradle-plugins/.idea/
|
/gradle-plugins/.idea/
|
||||||
|
|
||||||
# Include the -parameters compiler option by default in IntelliJ required for serialization.
|
# Include the -parameters compiler option by default in IntelliJ required for serialization.
|
||||||
|
@ -6,6 +6,7 @@ import net.corda.core.identity.Party
|
|||||||
import net.corda.core.identity.PartyAndCertificate
|
import net.corda.core.identity.PartyAndCertificate
|
||||||
import net.corda.core.internal.FlowStateMachine
|
import net.corda.core.internal.FlowStateMachine
|
||||||
import net.corda.core.internal.abbreviate
|
import net.corda.core.internal.abbreviate
|
||||||
|
import net.corda.core.internal.uncheckedCast
|
||||||
import net.corda.core.messaging.DataFeed
|
import net.corda.core.messaging.DataFeed
|
||||||
import net.corda.core.node.NodeInfo
|
import net.corda.core.node.NodeInfo
|
||||||
import net.corda.core.node.ServiceHub
|
import net.corda.core.node.ServiceHub
|
||||||
@ -177,6 +178,38 @@ abstract class FlowLogic<out T> {
|
|||||||
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions)
|
return stateMachine.receive(receiveType, otherParty, flowUsedForSessions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Suspends until a message has been received for each session in the specified [sessions].
|
||||||
|
*
|
||||||
|
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
|
||||||
|
*
|
||||||
|
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||||
|
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||||
|
* corrupted data in order to exploit your code.
|
||||||
|
*
|
||||||
|
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
|
||||||
|
*/
|
||||||
|
@Suspendable
|
||||||
|
open fun receiveAll(sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||||
|
return stateMachine.receiveAll(sessions, this)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Suspends until a message has been received for each session in the specified [sessions].
|
||||||
|
*
|
||||||
|
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
|
||||||
|
*
|
||||||
|
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||||
|
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||||
|
* corrupted data in order to exploit your code.
|
||||||
|
*
|
||||||
|
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||||
|
*/
|
||||||
|
@Suspendable
|
||||||
|
open fun <R : Any> receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>> {
|
||||||
|
enforceNoDuplicates(sessions)
|
||||||
|
return castMapValuesToKnownType(receiveAll(associateSessionsToReceiveType(receiveType, sessions)))
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
|
* Queues the given [payload] for sending to the [otherParty] and continues without suspending.
|
||||||
*
|
*
|
||||||
@ -231,7 +264,6 @@ abstract class FlowLogic<out T> {
|
|||||||
stateMachine.checkFlowPermission(permissionName, extraAuditData)
|
stateMachine.checkFlowPermission(permissionName, extraAuditData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Flows can call this method to record application level flow audit events
|
* Flows can call this method to record application level flow audit events
|
||||||
* @param eventType is a string representing the type of event. Each flow is given a distinct namespace for these names.
|
* @param eventType is a string representing the type of event. Each flow is given a distinct namespace for these names.
|
||||||
@ -334,6 +366,18 @@ abstract class FlowLogic<out T> {
|
|||||||
ours.setChildProgressTracker(ours.currentStep, theirs)
|
ours.setChildProgressTracker(ours.currentStep, theirs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun enforceNoDuplicates(sessions: List<FlowSession>) {
|
||||||
|
require(sessions.size == sessions.toSet().size) { "A flow session can only appear once as argument." }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun <R> associateSessionsToReceiveType(receiveType: Class<R>, sessions: List<FlowSession>): Map<FlowSession, Class<R>> {
|
||||||
|
return sessions.associateByTo(LinkedHashMap(), { it }, { receiveType })
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun <R> castMapValuesToKnownType(map: Map<FlowSession, UntrustworthyData<Any>>): List<UntrustworthyData<R>> {
|
||||||
|
return map.values.map { uncheckedCast<Any, UntrustworthyData<R>>(it) }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -30,20 +30,20 @@ interface FlowStateMachine<R> {
|
|||||||
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>): Unit
|
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction
|
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction
|
||||||
|
|
||||||
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>): Unit
|
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>)
|
||||||
|
|
||||||
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>): Unit
|
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>)
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
|
fun flowStackSnapshot(flowClass: Class<out FlowLogic<*>>): FlowStackSnapshot?
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>): Unit
|
fun persistFlowStackSnapshot(flowClass: Class<out FlowLogic<*>>)
|
||||||
|
|
||||||
val serviceHub: ServiceHub
|
val serviceHub: ServiceHub
|
||||||
val logger: Logger
|
val logger: Logger
|
||||||
@ -51,4 +51,7 @@ interface FlowStateMachine<R> {
|
|||||||
val resultFuture: CordaFuture<R>
|
val resultFuture: CordaFuture<R>
|
||||||
val flowInitiator: FlowInitiator
|
val flowInitiator: FlowInitiator
|
||||||
val ourIdentityAndCert: PartyAndCertificate
|
val ourIdentityAndCert: PartyAndCertificate
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>>
|
||||||
}
|
}
|
||||||
|
115
core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt
Normal file
115
core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt
Normal file
@ -0,0 +1,115 @@
|
|||||||
|
package net.corda.core.flows
|
||||||
|
|
||||||
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
|
import net.corda.core.utilities.UntrustworthyData
|
||||||
|
import net.corda.core.utilities.unwrap
|
||||||
|
import net.corda.node.internal.InitiatedFlowFactory
|
||||||
|
import net.corda.node.internal.StartedNode
|
||||||
|
import kotlin.reflect.KClass
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allows to simplify writing flows that simply rend a message back to an initiating flow.
|
||||||
|
*/
|
||||||
|
class Answer<out R : Any>(session: FlowSession, override val answer: R, closure: (result: R) -> Unit = {}) : SimpleAnswer<R>(session, closure)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allows to simplify writing flows that simply rend a message back to an initiating flow.
|
||||||
|
*/
|
||||||
|
abstract class SimpleAnswer<out R : Any>(private val session: FlowSession, private val closure: (result: R) -> Unit = {}) : FlowLogic<Unit>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call() {
|
||||||
|
val tmp = answer
|
||||||
|
closure(tmp)
|
||||||
|
session.send(tmp)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected abstract val answer: R
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A flow that does not do anything when triggered.
|
||||||
|
*/
|
||||||
|
class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic<Unit>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call() = closure()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allows to register a flow of type [R] against an initiating flow of type [I].
|
||||||
|
*/
|
||||||
|
inline fun <I : FlowLogic<*>, reified R : FlowLogic<*>> StartedNode<*>.registerInitiatedFlow(initiatingFlowType: KClass<I>, crossinline construct: (session: FlowSession) -> R) {
|
||||||
|
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R].
|
||||||
|
*/
|
||||||
|
inline fun <I : FlowLogic<*>, reified R : Any> StartedNode<*>.registerAnswer(initiatingFlowType: KClass<I>, value: R) {
|
||||||
|
internals.internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts data from a [Map[FlowSession, UntrustworthyData<Any>]] without performing checks and casting to [R].
|
||||||
|
*/
|
||||||
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
infix fun <R : Any> Map<FlowSession, UntrustworthyData<Any>>.from(session: FlowSession): R = this[session]!!.unwrap { it as R }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [Pair([session], [Class])] from this [Class].
|
||||||
|
*/
|
||||||
|
infix fun <T : Class<out Any>> T.from(session: FlowSession): Pair<FlowSession, T> = session to this
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a [Pair([session], [Class])] from this [KClass].
|
||||||
|
*/
|
||||||
|
infix fun <T : Any> KClass<T>.from(session: FlowSession): Pair<FlowSession, Class<T>> = session to this.javaObjectType
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Suspends until a message has been received for each session in the specified [sessions].
|
||||||
|
*
|
||||||
|
* Consider [receiveAll(receiveType: Class<R>, sessions: List<FlowSession>): List<UntrustworthyData<R>>] when the same type is expected from all sessions.
|
||||||
|
*
|
||||||
|
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||||
|
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||||
|
* corrupted data in order to exploit your code.
|
||||||
|
*
|
||||||
|
* @returns a [Map] containing the objects received, wrapped in an [UntrustworthyData], by the [FlowSession]s who sent them.
|
||||||
|
*/
|
||||||
|
@Suspendable
|
||||||
|
fun FlowLogic<*>.receiveAll(session: Pair<FlowSession, Class<out Any>>, vararg sessions: Pair<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||||
|
val allSessions = arrayOf(session, *sessions)
|
||||||
|
allSessions.enforceNoDuplicates()
|
||||||
|
return receiveAll(mapOf(*allSessions))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Suspends until a message has been received for each session in the specified [sessions].
|
||||||
|
*
|
||||||
|
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
|
||||||
|
*
|
||||||
|
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||||
|
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||||
|
* corrupted data in order to exploit your code.
|
||||||
|
*
|
||||||
|
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||||
|
*/
|
||||||
|
@Suspendable
|
||||||
|
fun <R : Any> FlowLogic<*>.receiveAll(receiveType: Class<R>, session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(receiveType, listOf(session, *sessions))
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Suspends until a message has been received for each session in the specified [sessions].
|
||||||
|
*
|
||||||
|
* Consider [sessions: Map<FlowSession, Class<out Any>>): Map<FlowSession, UntrustworthyData<Any>>] when sessions are expected to receive different types.
|
||||||
|
*
|
||||||
|
* Remember that when receiving data from other parties the data should not be trusted until it's been thoroughly
|
||||||
|
* verified for consistency and that all expectations are satisfied, as a malicious peer may send you subtly
|
||||||
|
* corrupted data in order to exploit your code.
|
||||||
|
*
|
||||||
|
* @returns a [List] containing the objects received, wrapped in an [UntrustworthyData], with the same order of [sessions].
|
||||||
|
*/
|
||||||
|
@Suspendable
|
||||||
|
inline fun <reified R : Any> FlowLogic<*>.receiveAll(session: FlowSession, vararg sessions: FlowSession): List<UntrustworthyData<R>> = receiveAll(R::class.javaObjectType, listOf(session, *sessions))
|
||||||
|
|
||||||
|
private fun Array<out Pair<FlowSession, Class<out Any>>>.enforceNoDuplicates() {
|
||||||
|
require(this.size == this.toSet().size) { "A flow session can only appear once as argument." }
|
||||||
|
}
|
@ -0,0 +1,87 @@
|
|||||||
|
package net.corda.core.flows
|
||||||
|
|
||||||
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
|
import net.corda.core.identity.Party
|
||||||
|
import net.corda.core.utilities.UntrustworthyData
|
||||||
|
import net.corda.core.utilities.getOrThrow
|
||||||
|
import net.corda.core.utilities.unwrap
|
||||||
|
import net.corda.testing.chooseIdentity
|
||||||
|
import net.corda.testing.node.network
|
||||||
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
class ReceiveMultipleFlowTests {
|
||||||
|
@Test
|
||||||
|
fun `receive all messages in parallel using map style`() {
|
||||||
|
network(3) { nodes, _ ->
|
||||||
|
val doubleValue = 5.0
|
||||||
|
nodes[1].registerAnswer(AlgorithmDefinition::class, doubleValue)
|
||||||
|
val stringValue = "Thriller"
|
||||||
|
nodes[2].registerAnswer(AlgorithmDefinition::class, stringValue)
|
||||||
|
|
||||||
|
val flow = nodes[0].services.startFlow(ParallelAlgorithmMap(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity()))
|
||||||
|
runNetwork()
|
||||||
|
|
||||||
|
val result = flow.resultFuture.getOrThrow()
|
||||||
|
|
||||||
|
assertThat(result).isEqualTo(doubleValue * stringValue.length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `receive all messages in parallel using list style`() {
|
||||||
|
network(3) { nodes, _ ->
|
||||||
|
val value1 = 5.0
|
||||||
|
nodes[1].registerAnswer(ParallelAlgorithmList::class, value1)
|
||||||
|
val value2 = 6.0
|
||||||
|
nodes[2].registerAnswer(ParallelAlgorithmList::class, value2)
|
||||||
|
|
||||||
|
val flow = nodes[0].services.startFlow(ParallelAlgorithmList(nodes[1].info.chooseIdentity(), nodes[2].info.chooseIdentity()))
|
||||||
|
runNetwork()
|
||||||
|
val data = flow.resultFuture.getOrThrow()
|
||||||
|
|
||||||
|
assertThat(data[0]).isEqualTo(value1)
|
||||||
|
assertThat(data[1]).isEqualTo(value2)
|
||||||
|
assertThat(data.fold(1.0) { a, b -> a * b }).isEqualTo(value1 * value2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ParallelAlgorithmMap(doubleMember: Party, stringMember: Party) : AlgorithmDefinition(doubleMember, stringMember) {
|
||||||
|
@Suspendable
|
||||||
|
override fun askMembersForData(doubleMember: Party, stringMember: Party): Data {
|
||||||
|
val doubleSession = initiateFlow(doubleMember)
|
||||||
|
val stringSession = initiateFlow(stringMember)
|
||||||
|
val rawData = receiveAll(Double::class from doubleSession, String::class from stringSession)
|
||||||
|
return Data(rawData from doubleSession, rawData from stringSession)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@InitiatingFlow
|
||||||
|
class ParallelAlgorithmList(private val member1: Party, private val member2: Party) : FlowLogic<List<Double>>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call(): List<Double> {
|
||||||
|
val session1 = initiateFlow(member1)
|
||||||
|
val session2 = initiateFlow(member2)
|
||||||
|
val data = receiveAll<Double>(session1, session2)
|
||||||
|
return computeAnswer(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun computeAnswer(data: List<UntrustworthyData<Double>>): List<Double> {
|
||||||
|
return data.map { element -> element.unwrap { it } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@InitiatingFlow
|
||||||
|
abstract class AlgorithmDefinition(private val doubleMember: Party, private val stringMember: Party) : FlowLogic<Double>() {
|
||||||
|
protected data class Data(val double: Double, val string: String)
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
protected abstract fun askMembersForData(doubleMember: Party, stringMember: Party): Data
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
override fun call(): Double {
|
||||||
|
val (double, string) = askMembersForData(doubleMember, stringMember)
|
||||||
|
return double * string.length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -6,6 +6,7 @@ from the previous milestone release.
|
|||||||
|
|
||||||
UNRELEASED
|
UNRELEASED
|
||||||
----------
|
----------
|
||||||
|
* ``FlowLogic`` now exposes a series of function called ``receiveAll(...)`` allowing to join ``receive(...)`` instructions.
|
||||||
|
|
||||||
* ``Cordform`` and node identity generation
|
* ``Cordform`` and node identity generation
|
||||||
* Cordform may not specify a value for ``NetworkMap``, when that happens, during the task execution the following happens:
|
* Cordform may not specify a value for ``NetworkMap``, when that happens, during the task execution the following happens:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
package net.corda.node.services.statemachine
|
package net.corda.node.services.statemachine
|
||||||
|
|
||||||
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
import net.corda.core.crypto.SecureHash
|
import net.corda.core.crypto.SecureHash
|
||||||
|
|
||||||
interface FlowIORequest {
|
interface FlowIORequest {
|
||||||
@ -8,7 +9,9 @@ interface FlowIORequest {
|
|||||||
val stackTraceInCaseOfProblems: StackSnapshot
|
val stackTraceInCaseOfProblems: StackSnapshot
|
||||||
}
|
}
|
||||||
|
|
||||||
interface WaitingRequest : FlowIORequest
|
interface WaitingRequest : FlowIORequest {
|
||||||
|
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
|
||||||
|
}
|
||||||
|
|
||||||
interface SessionedFlowIORequest : FlowIORequest {
|
interface SessionedFlowIORequest : FlowIORequest {
|
||||||
val session: FlowSessionInternal
|
val session: FlowSessionInternal
|
||||||
@ -21,6 +24,8 @@ interface SendRequest : SessionedFlowIORequest {
|
|||||||
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
|
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
|
||||||
val receiveType: Class<T>
|
val receiveType: Class<T>
|
||||||
val userReceiveType: Class<*>?
|
val userReceiveType: Class<*>?
|
||||||
|
|
||||||
|
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
|
||||||
}
|
}
|
||||||
|
|
||||||
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
|
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
|
||||||
@ -38,6 +43,63 @@ data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInte
|
|||||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
|
||||||
|
@Transient
|
||||||
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
|
|
||||||
|
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
|
||||||
|
return received.keys == requests.map { it.session }.toSet()
|
||||||
|
}
|
||||||
|
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
|
||||||
|
|
||||||
|
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
|
||||||
|
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
|
||||||
|
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
|
||||||
|
|
||||||
|
poll(receivedMessages)
|
||||||
|
return if (isComplete(receivedMessages)) {
|
||||||
|
receivedMessages
|
||||||
|
} else {
|
||||||
|
suspend(this)
|
||||||
|
poll(receivedMessages)
|
||||||
|
if (isComplete(receivedMessages)) {
|
||||||
|
receivedMessages
|
||||||
|
} else {
|
||||||
|
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Suspend {
|
||||||
|
@Suspendable
|
||||||
|
operator fun invoke(request: FlowIORequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
|
||||||
|
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
|
||||||
|
poll(request)?.let {
|
||||||
|
receivedMessages[request.session] = RequestMessage(request, it)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
|
||||||
|
return request.session.receivedMessages.poll()
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
|
||||||
|
|
||||||
|
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
|
||||||
|
|
||||||
|
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
|
||||||
|
}
|
||||||
|
|
||||||
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
|
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
|
||||||
@Transient
|
@Transient
|
||||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
@ -46,6 +108,8 @@ data class SendOnly(override val session: FlowSessionInternal, override val mess
|
|||||||
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
|
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
|
||||||
@Transient
|
@Transient
|
||||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
|
|
||||||
|
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
|
||||||
}
|
}
|
||||||
|
|
||||||
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
||||||
|
@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
|
|||||||
|
|
||||||
import net.corda.core.flows.FlowInfo
|
import net.corda.core.flows.FlowInfo
|
||||||
import net.corda.core.flows.FlowLogic
|
import net.corda.core.flows.FlowLogic
|
||||||
|
import net.corda.core.flows.FlowSession
|
||||||
import net.corda.core.identity.Party
|
import net.corda.core.identity.Party
|
||||||
import net.corda.node.services.statemachine.FlowSessionState.Initiated
|
import net.corda.node.services.statemachine.FlowSessionState.Initiated
|
||||||
import net.corda.node.services.statemachine.FlowSessionState.Initiating
|
import net.corda.node.services.statemachine.FlowSessionState.Initiating
|
||||||
@ -15,6 +16,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
|
|||||||
// TODO rename this
|
// TODO rename this
|
||||||
class FlowSessionInternal(
|
class FlowSessionInternal(
|
||||||
val flow: FlowLogic<*>,
|
val flow: FlowLogic<*>,
|
||||||
|
val flowSession : FlowSession,
|
||||||
val ourSessionId: Long,
|
val ourSessionId: Long,
|
||||||
val initiatingParty: Party?,
|
val initiatingParty: Party?,
|
||||||
var state: FlowSessionState,
|
var state: FlowSessionState,
|
||||||
|
@ -12,9 +12,13 @@ import net.corda.core.crypto.random63BitValue
|
|||||||
import net.corda.core.flows.*
|
import net.corda.core.flows.*
|
||||||
import net.corda.core.identity.Party
|
import net.corda.core.identity.Party
|
||||||
import net.corda.core.identity.PartyAndCertificate
|
import net.corda.core.identity.PartyAndCertificate
|
||||||
import net.corda.core.internal.*
|
import net.corda.core.internal.FlowStateMachine
|
||||||
|
import net.corda.core.internal.abbreviate
|
||||||
import net.corda.core.internal.concurrent.OpenFuture
|
import net.corda.core.internal.concurrent.OpenFuture
|
||||||
import net.corda.core.internal.concurrent.openFuture
|
import net.corda.core.internal.concurrent.openFuture
|
||||||
|
import net.corda.core.internal.isRegularFile
|
||||||
|
import net.corda.core.internal.staticField
|
||||||
|
import net.corda.core.internal.uncheckedCast
|
||||||
import net.corda.core.transactions.SignedTransaction
|
import net.corda.core.transactions.SignedTransaction
|
||||||
import net.corda.core.utilities.*
|
import net.corda.core.utilities.*
|
||||||
import net.corda.node.services.api.FlowAppAuditEvent
|
import net.corda.node.services.api.FlowAppAuditEvent
|
||||||
@ -171,8 +175,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
"@${InitiatingFlow::class.java.simpleName} sub-flow."
|
"@${InitiatingFlow::class.java.simpleName} sub-flow."
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
createNewSession(otherParty, sessionFlow)
|
|
||||||
val flowSession = FlowSessionImpl(otherParty)
|
val flowSession = FlowSessionImpl(otherParty)
|
||||||
|
createNewSession(otherParty, flowSession, sessionFlow)
|
||||||
flowSession.stateMachine = this
|
flowSession.stateMachine = this
|
||||||
flowSession.sessionFlow = sessionFlow
|
flowSession.sessionFlow = sessionFlow
|
||||||
return flowSession
|
return flowSession
|
||||||
@ -299,6 +303,22 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
|
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||||
|
val requests = ArrayList<ReceiveOnly<SessionData>>()
|
||||||
|
for ((session, receiveType) in sessions) {
|
||||||
|
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
|
||||||
|
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
|
||||||
|
}
|
||||||
|
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
|
||||||
|
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
|
||||||
|
for ((sessionInternal, requestAndMessage) in receivedMessages) {
|
||||||
|
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
|
||||||
|
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method will suspend the state machine and wait for incoming session init response from other party.
|
* This method will suspend the state machine and wait for incoming session init response from other party.
|
||||||
*/
|
*/
|
||||||
@ -362,10 +382,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
|
|
||||||
private fun createNewSession(
|
private fun createNewSession(
|
||||||
otherParty: Party,
|
otherParty: Party,
|
||||||
|
flowSession: FlowSession,
|
||||||
sessionFlow: FlowLogic<*>
|
sessionFlow: FlowLogic<*>
|
||||||
) {
|
) {
|
||||||
logger.trace { "Creating a new session with $otherParty" }
|
logger.trace { "Creating a new session with $otherParty" }
|
||||||
val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
|
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
|
||||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -402,6 +423,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
|
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
|
||||||
|
@Suspendable
|
||||||
|
override fun invoke(request: FlowIORequest) {
|
||||||
|
suspend(request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
|
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
|
||||||
val polledMessage = session.receivedMessages.poll()
|
val polledMessage = session.receivedMessages.poll()
|
||||||
|
@ -15,7 +15,11 @@ import net.corda.core.crypto.SecureHash
|
|||||||
import net.corda.core.crypto.random63BitValue
|
import net.corda.core.crypto.random63BitValue
|
||||||
import net.corda.core.flows.*
|
import net.corda.core.flows.*
|
||||||
import net.corda.core.identity.Party
|
import net.corda.core.identity.Party
|
||||||
import net.corda.core.internal.*
|
import net.corda.core.internal.FlowStateMachine
|
||||||
|
import net.corda.core.internal.ThreadBox
|
||||||
|
import net.corda.core.internal.bufferUntilSubscribed
|
||||||
|
import net.corda.core.internal.castIfPossible
|
||||||
|
import net.corda.core.internal.uncheckedCast
|
||||||
import net.corda.core.messaging.DataFeed
|
import net.corda.core.messaging.DataFeed
|
||||||
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
|
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
|
||||||
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
|
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
|
||||||
@ -342,8 +346,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
|
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
|
||||||
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
|
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
|
||||||
val waitingForResponse = session.fiber.waitingForResponse
|
val waitingForResponse = session.fiber.waitingForResponse
|
||||||
return (waitingForResponse as? ReceiveRequest<*>)?.session === session ||
|
return waitingForResponse?.shouldResume(message, session) ?: false
|
||||||
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
|
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
|
||||||
@ -362,6 +365,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
val session = FlowSessionInternal(
|
val session = FlowSessionInternal(
|
||||||
flow,
|
flow,
|
||||||
|
flowSession,
|
||||||
random63BitValue(),
|
random63BitValue(),
|
||||||
sender,
|
sender,
|
||||||
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
|
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
|
||||||
|
@ -51,6 +51,7 @@ import net.corda.testing.resetTestSerialization
|
|||||||
import net.corda.testing.testNodeConfiguration
|
import net.corda.testing.testNodeConfiguration
|
||||||
import org.apache.activemq.artemis.utils.ReusableLatch
|
import org.apache.activemq.artemis.utils.ReusableLatch
|
||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
|
import java.io.Closeable
|
||||||
import java.math.BigInteger
|
import java.math.BigInteger
|
||||||
import java.nio.file.Path
|
import java.nio.file.Path
|
||||||
import java.security.KeyPair
|
import java.security.KeyPair
|
||||||
@ -81,12 +82,11 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
|||||||
servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy =
|
servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy =
|
||||||
InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(),
|
InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(),
|
||||||
private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory,
|
private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory,
|
||||||
private val initialiseSerialization: Boolean = true) {
|
private val initialiseSerialization: Boolean = true) : Closeable {
|
||||||
companion object {
|
companion object {
|
||||||
// TODO In future PR we're removing the concept of network map node so the details of this mock are not important.
|
// TODO In future PR we're removing the concept of network map node so the details of this mock are not important.
|
||||||
val MOCK_NET_MAP = Party(CordaX500Name(organisation = "Mock Network Map", locality = "Madrid", country = "ES"), DUMMY_KEY_1.public)
|
val MOCK_NET_MAP = Party(CordaX500Name(organisation = "Mock Network Map", locality = "Madrid", country = "ES"), DUMMY_KEY_1.public)
|
||||||
}
|
}
|
||||||
|
|
||||||
var nextNodeId = 0
|
var nextNodeId = 0
|
||||||
private set
|
private set
|
||||||
private val filesystem = Jimfs.newFileSystem(unix())
|
private val filesystem = Jimfs.newFileSystem(unix())
|
||||||
@ -445,4 +445,17 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
|||||||
fun waitQuiescent() {
|
fun waitQuiescent() {
|
||||||
busyLatch.await()
|
busyLatch.await()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun close() {
|
||||||
|
stopNodes()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun network(nodesCount: Int, action: MockNetwork.(nodes: List<StartedNode<MockNetwork.MockNode>>, notary: StartedNode<MockNetwork.MockNode>) -> Unit) {
|
||||||
|
MockNetwork().use {
|
||||||
|
it.runNetwork()
|
||||||
|
val notary = it.createNotaryNode()
|
||||||
|
val nodes = (1..nodesCount).map { _ -> it.createPartyNode() }
|
||||||
|
action(it, nodes, notary)
|
||||||
|
}
|
||||||
}
|
}
|
Loading…
Reference in New Issue
Block a user