[CORDA-683] Enable receiveAll() from Flows.

This commit is contained in:
Michele Sollecito
2017-10-09 13:46:37 +01:00
committed by GitHub
parent 20a30b30da
commit 29a101c378
11 changed files with 377 additions and 15 deletions

View File

@ -1,5 +1,6 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash
interface FlowIORequest {
@ -8,7 +9,9 @@ interface FlowIORequest {
val stackTraceInCaseOfProblems: StackSnapshot
}
interface WaitingRequest : FlowIORequest
interface WaitingRequest : FlowIORequest {
fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean
}
interface SessionedFlowIORequest : FlowIORequest {
val session: FlowSessionInternal
@ -21,6 +24,8 @@ interface SendRequest : SessionedFlowIORequest {
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T>
val userReceiveType: Class<*>?
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
}
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
@ -38,6 +43,63 @@ data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInte
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
private fun isComplete(received: LinkedHashMap<FlowSessionInternal, RequestMessage>): Boolean {
return received.keys == requests.map { it.session }.toSet()
}
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
}
@Suspendable
fun suspendAndExpectReceive(suspend: Suspend): Map<FlowSessionInternal, RequestMessage> {
val receivedMessages = LinkedHashMap<FlowSessionInternal, RequestMessage>()
poll(receivedMessages)
return if (isComplete(receivedMessages)) {
receivedMessages
} else {
suspend(this)
poll(receivedMessages)
if (isComplete(receivedMessages)) {
receivedMessages
} else {
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
}
}
}
interface Suspend {
@Suspendable
operator fun invoke(request: FlowIORequest)
}
@Suspendable
private fun poll(receivedMessages: LinkedHashMap<FlowSessionInternal, RequestMessage>) {
return requests.filter { it.session !in receivedMessages.keys }.forEach { request ->
poll(request)?.let {
receivedMessages[request.session] = RequestMessage(request, it)
}
}
}
@Suspendable
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
return request.session.receivedMessages.poll()
}
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
}
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
@ -46,6 +108,8 @@ data class SendOnly(override val session: FlowSessionInternal, override val mess
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.identity.Party
import net.corda.node.services.statemachine.FlowSessionState.Initiated
import net.corda.node.services.statemachine.FlowSessionState.Initiating
@ -15,6 +16,7 @@ import java.util.concurrent.ConcurrentLinkedQueue
// TODO rename this
class FlowSessionInternal(
val flow: FlowLogic<*>,
val flowSession : FlowSession,
val ourSessionId: Long,
val initiatingParty: Party?,
var state: FlowSessionState,

View File

@ -12,9 +12,13 @@ import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.*
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.abbreviate
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.isRegularFile
import net.corda.core.internal.staticField
import net.corda.core.internal.uncheckedCast
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.*
import net.corda.node.services.api.FlowAppAuditEvent
@ -171,8 +175,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
"@${InitiatingFlow::class.java.simpleName} sub-flow."
)
}
createNewSession(otherParty, sessionFlow)
val flowSession = FlowSessionImpl(otherParty)
createNewSession(otherParty, flowSession, sessionFlow)
flowSession.stateMachine = this
flowSession.sessionFlow = sessionFlow
return flowSession
@ -299,6 +303,22 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
}
@Suspendable
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
val requests = ArrayList<ReceiveOnly<SessionData>>()
for ((session, receiveType) in sessions) {
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
}
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
for ((sessionInternal, requestAndMessage) in receivedMessages) {
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
}
return result
}
/**
* This method will suspend the state machine and wait for incoming session init response from other party.
*/
@ -362,10 +382,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
private fun createNewSession(
otherParty: Party,
flowSession: FlowSession,
sessionFlow: FlowLogic<*>
) {
logger.trace { "Creating a new session with $otherParty" }
val session = FlowSessionInternal(sessionFlow, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session
}
@ -402,6 +423,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
}
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
@Suspendable
override fun invoke(request: FlowIORequest) {
suspend(request)
}
}
@Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
val polledMessage = session.receivedMessages.poll()

View File

@ -15,7 +15,11 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
@ -342,8 +346,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean {
val waitingForResponse = session.fiber.waitingForResponse
return (waitingForResponse as? ReceiveRequest<*>)?.session === session ||
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
return waitingForResponse?.shouldResume(message, session) ?: false
}
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
@ -362,6 +365,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
val session = FlowSessionInternal(
flow,
flowSession,
random63BitValue(),
sender,
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))