mirror of
https://github.com/corda/corda.git
synced 2025-06-16 22:28:15 +00:00
[CORDA-683] Enable receiveAll()
from Flows.
This commit is contained in:
committed by
GitHub
parent
20a30b30da
commit
29a101c378
@ -1,5 +1,6 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.crypto.SecureHash
|
||||
|
||||
interface FlowIORequest {
|
||||
@ -8,7 +9,9 @@ interface FlowIORequest {
|
||||
val stackTraceInCaseOfProblems: StackSnapshot
|
||||
}
|
||||
|
||||
interface WaitingRequest : FlowIORequest
|
||||
interface WaitingRequest : FlowIORequest {
|
||||
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
|
||||
}
|
||||
|
||||
interface SessionedFlowIORequest : FlowIORequest {
|
||||
val session: FlowSessionInternal
|
||||
@ -21,6 +24,8 @@ interface SendRequest : SessionedFlowIORequest {
|
||||
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
|
||||
val receiveType: Class<T>
|
||||
val userReceiveType: Class<*>?
|
||||
|
||||
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
|
||||
}
|
||||
|
||||
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
|
||||
@ -38,6 +43,63 @@ data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInte
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
}
|
||||
|
||||
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
|
||||
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
|
||||
return received.keys == requests.map { it.session }.toSet()
|
||||
}
|
||||
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
|
||||
|
||||
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
|
||||
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
|
||||
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
|
||||
|
||||
poll(receivedMessages)
|
||||
return if (isComplete(receivedMessages)) {
|
||||
receivedMessages
|
||||
} else {
|
||||
suspend(this)
|
||||
poll(receivedMessages)
|
||||
if (isComplete(receivedMessages)) {
|
||||
receivedMessages
|
||||
} else {
|
||||
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface Suspend {
|
||||
@Suspendable
|
||||
operator fun invoke(request: FlowIORequest)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
|
||||
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
|
||||
poll(request)?.let {
|
||||
receivedMessages[request.session] = RequestMessage(request, it)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
|
||||
return request.session.receivedMessages.poll()
|
||||
}
|
||||
|
||||
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
|
||||
|
||||
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
|
||||
|
||||
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
|
||||
}
|
||||
|
||||
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
@ -46,6 +108,8 @@ data class SendOnly(override val session: FlowSessionInternal, override val mess
|
||||
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
|
||||
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
|
||||
}
|
||||
|
||||
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
||||
|
@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.flows.FlowInfo
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.node.services.statemachine.FlowSessionState.Initiated
|
||||
import net.corda.node.services.statemachine.FlowSessionState.Initiating
|
||||
@ -15,6 +16,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
|
||||
// TODO rename this
|
||||
class FlowSessionInternal(
|
||||
val flow: FlowLogic<*>,
|
||||
val flowSession : FlowSession,
|
||||
val ourSessionId: Long,
|
||||
val initiatingParty: Party?,
|
||||
var state: FlowSessionState,
|
||||
|
@ -12,9 +12,13 @@ import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.identity.PartyAndCertificate
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.abbreviate
|
||||
import net.corda.core.internal.concurrent.OpenFuture
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.internal.isRegularFile
|
||||
import net.corda.core.internal.staticField
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.*
|
||||
import net.corda.node.services.api.FlowAppAuditEvent
|
||||
@ -171,8 +175,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
"@${InitiatingFlow::class.java.simpleName} sub-flow."
|
||||
)
|
||||
}
|
||||
createNewSession(otherParty, sessionFlow)
|
||||
val flowSession = FlowSessionImpl(otherParty)
|
||||
createNewSession(otherParty, flowSession, sessionFlow)
|
||||
flowSession.stateMachine = this
|
||||
flowSession.sessionFlow = sessionFlow
|
||||
return flowSession
|
||||
@ -299,6 +303,22 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
|
||||
val requests = ArrayList<ReceiveOnly<SessionData>>()
|
||||
for ((session, receiveType) in sessions) {
|
||||
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
|
||||
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
|
||||
}
|
||||
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
|
||||
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
|
||||
for ((sessionInternal, requestAndMessage) in receivedMessages) {
|
||||
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
|
||||
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will suspend the state machine and wait for incoming session init response from other party.
|
||||
*/
|
||||
@ -362,10 +382,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
|
||||
private fun createNewSession(
|
||||
otherParty: Party,
|
||||
flowSession: FlowSession,
|
||||
sessionFlow: FlowLogic<*>
|
||||
) {
|
||||
logger.trace { "Creating a new session with $otherParty" }
|
||||
val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
|
||||
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
|
||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||
}
|
||||
|
||||
@ -402,6 +423,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
|
||||
}
|
||||
|
||||
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
|
||||
@Suspendable
|
||||
override fun invoke(request: FlowIORequest) {
|
||||
suspend(request)
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
|
||||
val polledMessage = session.receivedMessages.poll()
|
||||
|
@ -15,7 +15,11 @@ import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.ThreadBox
|
||||
import net.corda.core.internal.bufferUntilSubscribed
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
|
||||
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
|
||||
@ -342,8 +346,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
|
||||
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
|
||||
val waitingForResponse = session.fiber.waitingForResponse
|
||||
return (waitingForResponse as? ReceiveRequest<*>)?.session === session ||
|
||||
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
|
||||
return waitingForResponse?.shouldResume(message, session) ?: false
|
||||
}
|
||||
|
||||
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
|
||||
@ -362,6 +365,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
val session = FlowSessionInternal(
|
||||
flow,
|
||||
flowSession,
|
||||
random63BitValue(),
|
||||
sender,
|
||||
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
|
||||
|
Reference in New Issue
Block a user