mirror of
https://github.com/corda/corda.git
synced 2024-12-20 05:28:21 +00:00
Some clean up of the flow code
This commit is contained in:
parent
95a33168d8
commit
e589031d4b
@ -27,18 +27,18 @@ import rx.Observable
|
||||
*/
|
||||
abstract class FlowLogic<out T> {
|
||||
|
||||
/** Reference to the [Fiber] instance that is the top level controller for the entire flow. */
|
||||
lateinit var fsm: FlowStateMachine<*>
|
||||
/** Reference to the [FlowStateMachine] instance that is the top level controller for the entire flow. */
|
||||
lateinit var stateMachine: FlowStateMachine<*>
|
||||
|
||||
/** This is where you should log things to. */
|
||||
val logger: Logger get() = fsm.logger
|
||||
val logger: Logger get() = stateMachine.logger
|
||||
|
||||
/**
|
||||
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
||||
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
|
||||
* access this lazily or from inside [call].
|
||||
*/
|
||||
val serviceHub: ServiceHub get() = fsm.serviceHub
|
||||
val serviceHub: ServiceHub get() = stateMachine.serviceHub
|
||||
|
||||
private var sessionFlow: FlowLogic<*> = this
|
||||
|
||||
@ -56,19 +56,19 @@ abstract class FlowLogic<out T> {
|
||||
|
||||
@Suspendable
|
||||
fun <T : Any> sendAndReceive(receiveType: Class<T>, otherParty: Party, payload: Any): UntrustworthyData<T> {
|
||||
return fsm.sendAndReceive(otherParty, payload, receiveType, sessionFlow)
|
||||
return stateMachine.sendAndReceive(receiveType, otherParty, payload, sessionFlow)
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(T::class.java, otherParty)
|
||||
|
||||
@Suspendable
|
||||
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party): UntrustworthyData<T> {
|
||||
return fsm.receive(otherParty, receiveType, sessionFlow)
|
||||
return stateMachine.receive(receiveType, otherParty, sessionFlow)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
fun send(otherParty: Party, payload: Any) {
|
||||
fsm.send(otherParty, payload, sessionFlow)
|
||||
stateMachine.send(otherParty, payload, sessionFlow)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -82,7 +82,7 @@ abstract class FlowLogic<out T> {
|
||||
// TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way
|
||||
@Suspendable
|
||||
fun <R> subFlow(subLogic: FlowLogic<R>, shareParentSessions: Boolean = false): R {
|
||||
subLogic.fsm = fsm
|
||||
subLogic.stateMachine = stateMachine
|
||||
maybeWireUpProgressTracking(subLogic)
|
||||
if (shareParentSessions) {
|
||||
subLogic.sessionFlow = this
|
||||
@ -127,5 +127,4 @@ abstract class FlowLogic<out T> {
|
||||
Pair(it.currentStep.toString(), it.changes.map { it.toString() })
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -28,13 +28,13 @@ data class StateMachineRunId private constructor(val uuid: UUID) {
|
||||
*/
|
||||
interface FlowStateMachine<R> {
|
||||
@Suspendable
|
||||
fun <T : Any> sendAndReceive(otherParty: Party,
|
||||
fun <T : Any> sendAndReceive(receiveType: Class<T>,
|
||||
otherParty: Party,
|
||||
payload: Any,
|
||||
receiveType: Class<T>,
|
||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||
|
||||
@Suspendable
|
||||
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||
|
||||
@Suspendable
|
||||
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
|
||||
@ -48,4 +48,4 @@ interface FlowStateMachine<R> {
|
||||
val resultFuture: ListenableFuture<R>
|
||||
}
|
||||
|
||||
class FlowSessionException(message: String) : Exception(message)
|
||||
class FlowException(message: String) : RuntimeException(message)
|
||||
|
@ -59,7 +59,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
|
||||
val flow = FinalityFlow(tx, setOf(req.recipient))
|
||||
subFlow(flow)
|
||||
return CashFlowResult.Success(
|
||||
fsm.id,
|
||||
stateMachine.id,
|
||||
tx,
|
||||
"Cash payment transaction generated"
|
||||
)
|
||||
@ -95,7 +95,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
|
||||
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
|
||||
subFlow(FinalityFlow(tx, participants))
|
||||
return CashFlowResult.Success(
|
||||
fsm.id,
|
||||
stateMachine.id,
|
||||
tx,
|
||||
"Cash destruction transaction generated"
|
||||
)
|
||||
@ -116,7 +116,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
|
||||
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
|
||||
subFlow(BroadcastTransactionFlow(tx, setOf(req.recipient)))
|
||||
return CashFlowResult.Success(
|
||||
fsm.id,
|
||||
stateMachine.id,
|
||||
tx,
|
||||
"Cash issuance completed"
|
||||
)
|
||||
|
@ -54,7 +54,7 @@ class IssuerFlowTest {
|
||||
private fun runIssuerAndIssueRequester(amount: Amount<Currency>, issueToPartyAndRef: PartyAndReference) : RunResult {
|
||||
val issuerFuture = bankOfCordaNode.initiateSingleShotFlow(IssuerFlow.IssuanceRequester::class) {
|
||||
otherParty -> IssuerFlow.Issuer(issueToPartyAndRef.party)
|
||||
}.map { it.fsm }
|
||||
}.map { it.stateMachine }
|
||||
|
||||
val issueRequest = IssuanceRequester(amount, issueToPartyAndRef.party, issueToPartyAndRef.reference, bankOfCordaNode.info.legalIdentity)
|
||||
val issueRequestResultFuture = bankClientNode.services.startFlow(issueRequest).resultFuture
|
||||
|
@ -1,7 +1,6 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
||||
import net.corda.node.services.statemachine.StateMachineManager.SessionMessage
|
||||
|
||||
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
|
||||
interface FlowIORequest {
|
@ -7,15 +7,16 @@ import co.paralleluniverse.strands.Strand
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import net.corda.core.crypto.Party
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSessionException
|
||||
import net.corda.core.flows.FlowStateMachine
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
import net.corda.core.utilities.trace
|
||||
import net.corda.node.services.api.ServiceHubInternal
|
||||
import net.corda.node.services.statemachine.StateMachineManager.*
|
||||
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
||||
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
|
||||
import net.corda.node.utilities.StrandLocalTransactionManager
|
||||
import net.corda.node.utilities.createDatabaseTransaction
|
||||
import net.corda.node.utilities.databaseTransaction
|
||||
@ -31,7 +32,6 @@ import java.util.concurrent.ExecutionException
|
||||
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
val logic: FlowLogic<R>,
|
||||
scheduler: FiberScheduler) : Fiber<R>("flow", scheduler), FlowStateMachine<R> {
|
||||
|
||||
companion object {
|
||||
// Used to work around a small limitation in Quasar.
|
||||
private val QUASAR_UNBLOCKER = run {
|
||||
@ -76,7 +76,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>()
|
||||
|
||||
init {
|
||||
logic.fsm = this
|
||||
logic.stateMachine = this
|
||||
name = id.toString()
|
||||
}
|
||||
|
||||
@ -120,9 +120,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun <T : Any> sendAndReceive(otherParty: Party,
|
||||
override fun <T : Any> sendAndReceive(receiveType: Class<T>,
|
||||
otherParty: Party,
|
||||
payload: Any,
|
||||
receiveType: Class<T>,
|
||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
||||
val (session, new) = getSession(otherParty, sessionFlow, payload)
|
||||
val receivedSessionData = if (new) {
|
||||
@ -132,16 +132,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
val sendSessionData = createSessionData(session, payload)
|
||||
sendAndReceiveInternal<SessionData>(session, sendSessionData)
|
||||
}
|
||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
||||
return receivedSessionData.checkPayloadIs(receiveType)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun <T : Any> receive(otherParty: Party,
|
||||
receiveType: Class<T>,
|
||||
override fun <T : Any> receive(receiveType: Class<T>,
|
||||
otherParty: Party,
|
||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
||||
val session = getSession(otherParty, sessionFlow, null).first
|
||||
val receivedSessionData = receiveInternal<SessionData>(session)
|
||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
||||
return receiveInternal<SessionData>(session).checkPayloadIs(receiveType)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
@ -156,8 +155,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
private fun createSessionData(session: FlowSession, payload: Any): SessionData {
|
||||
val sessionState = session.state
|
||||
val peerSessionId = when (sessionState) {
|
||||
is StateMachineManager.FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
|
||||
is StateMachineManager.FlowSessionState.Initiated -> sessionState.peerSessionId
|
||||
is FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
|
||||
is FlowSessionState.Initiated -> sessionState.peerSessionId
|
||||
}
|
||||
return SessionData(peerSessionId, payload)
|
||||
}
|
||||
@ -167,16 +166,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
suspend(SendOnly(session, message))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
|
||||
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message
|
||||
private inline fun <reified M : ExistingSessionMessage> receiveInternal(session: FlowSession): ReceivedSessionMessage<M> {
|
||||
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
|
||||
}
|
||||
|
||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
|
||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message
|
||||
}
|
||||
|
||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage<M> {
|
||||
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
|
||||
session: FlowSession,
|
||||
message: SessionMessage): ReceivedSessionMessage<M> {
|
||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
|
||||
}
|
||||
|
||||
@ -203,20 +199,21 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
|
||||
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
|
||||
val (peerParty, sessionInitResponse) = sendAndReceiveInternalWithParty<SessionInitResponse>(session, sessionInit)
|
||||
val (peerParty, sessionInitResponse) = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
|
||||
if (sessionInitResponse is SessionConfirm) {
|
||||
require(session.state is FlowSessionState.Initiating)
|
||||
session.state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId)
|
||||
return session
|
||||
} else {
|
||||
sessionInitResponse as SessionReject
|
||||
throw FlowSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
|
||||
throw FlowException("Party $otherParty rejected session request: ${sessionInitResponse.errorMessage}")
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
|
||||
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
|
||||
private fun <M : ExistingSessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
|
||||
val session = receiveRequest.session
|
||||
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = session.receivedMessages.poll()
|
||||
|
||||
val polledMessage = getReceivedMessage()
|
||||
val receivedMessage = if (polledMessage != null) {
|
||||
@ -228,17 +225,21 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
} else {
|
||||
// Suspend while we wait for a receive
|
||||
suspend(receiveRequest)
|
||||
getReceivedMessage()
|
||||
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
|
||||
getReceivedMessage() ?:
|
||||
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " +
|
||||
"got nothing: $receiveRequest")
|
||||
}
|
||||
|
||||
if (receivedMessage.message is SessionEnd) {
|
||||
openSessions.values.remove(receiveRequest.session)
|
||||
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
|
||||
openSessions.values.remove(session)
|
||||
throw FlowException("Party ${session.state.sendToParty} has ended their flow but we were expecting to " +
|
||||
"receive ${receiveRequest.receiveType.simpleName} from them")
|
||||
} else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
|
||||
return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message))
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return receivedMessage as ReceivedSessionMessage<M>
|
||||
} else {
|
||||
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
|
||||
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " +
|
||||
"${receivedMessage.message}: $receiveRequest")
|
||||
}
|
||||
}
|
||||
|
||||
@ -292,5 +293,4 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
logger.error("Error during resume", t)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,43 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import net.corda.core.abbreviate
|
||||
import net.corda.core.crypto.Party
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
|
||||
interface SessionMessage
|
||||
|
||||
interface ExistingSessionMessage : SessionMessage {
|
||||
val recipientSessionId: Long
|
||||
}
|
||||
|
||||
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
|
||||
|
||||
interface SessionInitResponse : ExistingSessionMessage
|
||||
|
||||
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
|
||||
override val recipientSessionId: Long get() = initiatorSessionId
|
||||
}
|
||||
|
||||
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
|
||||
override val recipientSessionId: Long get() = initiatorSessionId
|
||||
}
|
||||
|
||||
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
|
||||
override fun toString(): String {
|
||||
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
|
||||
}
|
||||
}
|
||||
|
||||
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
|
||||
|
||||
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
|
||||
|
||||
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
|
||||
if (type.isInstance(message.payload)) {
|
||||
return UntrustworthyData(type.cast(message.payload))
|
||||
} else {
|
||||
throw FlowException("We were expecting a ${type.name} from $sender but we instead got a " +
|
||||
"${message.payload.javaClass.name} (${message.payload})")
|
||||
}
|
||||
}
|
@ -9,15 +9,19 @@ import com.esotericsoftware.kryo.Kryo
|
||||
import com.google.common.annotations.VisibleForTesting
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import kotlinx.support.jdk8.collections.removeIf
|
||||
import net.corda.core.*
|
||||
import net.corda.core.ThreadBox
|
||||
import net.corda.core.bufferUntilSubscribed
|
||||
import net.corda.core.crypto.Party
|
||||
import net.corda.core.crypto.commonName
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowStateMachine
|
||||
import net.corda.core.flows.StateMachineRunId
|
||||
import net.corda.core.messaging.ReceivedMessage
|
||||
import net.corda.core.messaging.TopicSession
|
||||
import net.corda.core.messaging.send
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.core.then
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.core.utilities.debug
|
||||
import net.corda.core.utilities.loggerFor
|
||||
@ -89,8 +93,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
val stateMachines = LinkedHashMap<FlowStateMachineImpl<*>, Checkpoint>()
|
||||
val changesPublisher = PublishSubject.create<Change>()
|
||||
|
||||
fun notifyChangeObservers(psm: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
|
||||
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(psm.logic, addOrRemove, psm.id))
|
||||
fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
|
||||
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id))
|
||||
}
|
||||
})
|
||||
|
||||
@ -125,7 +129,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
stateMachines.keys
|
||||
.map { it.logic }
|
||||
.filterIsInstance(flowClass)
|
||||
.map { it to (it.fsm as FlowStateMachineImpl<T>).resultFuture }
|
||||
.map { it to (it.stateMachine as FlowStateMachineImpl<T>).resultFuture }
|
||||
}
|
||||
}
|
||||
|
||||
@ -207,17 +211,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
serviceHub.networkService.addMessageHandler(sessionTopic) { message, reg ->
|
||||
executor.checkOnThread()
|
||||
val sessionMessage = message.data.deserialize<SessionMessage>()
|
||||
// TODO Look up the party with the full X.500 name instead of just the legal name
|
||||
val otherParty = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity
|
||||
if (otherParty != null) {
|
||||
when (sessionMessage) {
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, otherParty)
|
||||
is SessionInit -> onSessionInit(sessionMessage, otherParty)
|
||||
}
|
||||
} else {
|
||||
logger.error("Unknown peer ${message.peer} in $sessionMessage")
|
||||
}
|
||||
onSessionMessage(message)
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,26 +224,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
}
|
||||
|
||||
private fun onExistingSessionMessage(message: ExistingSessionMessage, otherParty: Party) {
|
||||
private fun onSessionMessage(message: ReceivedMessage) {
|
||||
val sessionMessage = message.data.deserialize<SessionMessage>()
|
||||
// TODO Look up the party with the full X.500 name instead of just the legal name
|
||||
val sender = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity
|
||||
if (sender != null) {
|
||||
when (sessionMessage) {
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
|
||||
is SessionInit -> onSessionInit(sessionMessage, sender)
|
||||
}
|
||||
} else {
|
||||
logger.error("Unknown peer ${message.peer} in $sessionMessage")
|
||||
}
|
||||
}
|
||||
|
||||
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
|
||||
val session = openSessions[message.recipientSessionId]
|
||||
if (session != null) {
|
||||
session.psm.logger.trace { "Received $message on $session" }
|
||||
session.fiber.logger.trace { "Received $message on $session" }
|
||||
if (message is SessionEnd) {
|
||||
openSessions.remove(message.recipientSessionId)
|
||||
}
|
||||
session.receivedMessages += ReceivedSessionMessage(otherParty, message)
|
||||
session.receivedMessages += ReceivedSessionMessage(sender, message)
|
||||
if (session.waitingForResponse) {
|
||||
// We only want to resume once, so immediately reset the flag.
|
||||
session.waitingForResponse = false
|
||||
updateCheckpoint(session.psm)
|
||||
resumeFiber(session.psm)
|
||||
updateCheckpoint(session.fiber)
|
||||
resumeFiber(session.fiber)
|
||||
}
|
||||
} else {
|
||||
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
|
||||
if (peerParty != null) {
|
||||
if (message is SessionConfirm) {
|
||||
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
|
||||
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId), null)
|
||||
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId))
|
||||
} else {
|
||||
logger.trace { "Ignoring session end message for already closed session: $message" }
|
||||
}
|
||||
@ -259,32 +267,32 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
}
|
||||
|
||||
private fun onSessionInit(sessionInit: SessionInit, otherParty: Party) {
|
||||
logger.trace { "Received $sessionInit $otherParty" }
|
||||
private fun onSessionInit(sessionInit: SessionInit, sender: Party) {
|
||||
logger.trace { "Received $sessionInit $sender" }
|
||||
val otherPartySessionId = sessionInit.initiatorSessionId
|
||||
try {
|
||||
val markerClass = Class.forName(sessionInit.flowName)
|
||||
val flowFactory = serviceHub.getFlowFactory(markerClass)
|
||||
if (flowFactory != null) {
|
||||
val flow = flowFactory(otherParty)
|
||||
val psm = createFiber(flow)
|
||||
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
|
||||
val flow = flowFactory(sender)
|
||||
val fiber = createFiber(flow)
|
||||
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(sender, otherPartySessionId))
|
||||
if (sessionInit.firstPayload != null) {
|
||||
session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
||||
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
||||
}
|
||||
openSessions[session.ourSessionId] = session
|
||||
psm.openSessions[Pair(flow, otherParty)] = session
|
||||
updateCheckpoint(psm)
|
||||
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
|
||||
psm.logger.debug { "Initiated from $sessionInit on $session" }
|
||||
startFiber(psm)
|
||||
fiber.openSessions[Pair(flow, sender)] = session
|
||||
updateCheckpoint(fiber)
|
||||
sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), fiber)
|
||||
fiber.logger.debug { "Initiated from $sessionInit on $session" }
|
||||
startFiber(fiber)
|
||||
} else {
|
||||
logger.warn("Unknown flow marker class in $sessionInit")
|
||||
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
|
||||
sendSessionMessage(sender, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"))
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
logger.warn("Received invalid $sessionInit", e)
|
||||
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
|
||||
sendSessionMessage(sender, SessionReject(otherPartySessionId, "Unable to establish session"))
|
||||
}
|
||||
}
|
||||
|
||||
@ -312,27 +320,27 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) }
|
||||
}
|
||||
|
||||
private fun initFiber(psm: FlowStateMachineImpl<*>) {
|
||||
psm.database = database
|
||||
psm.serviceHub = serviceHub
|
||||
psm.actionOnSuspend = { ioRequest ->
|
||||
updateCheckpoint(psm)
|
||||
private fun initFiber(fiber: FlowStateMachineImpl<*>) {
|
||||
fiber.database = database
|
||||
fiber.serviceHub = serviceHub
|
||||
fiber.actionOnSuspend = { ioRequest ->
|
||||
updateCheckpoint(fiber)
|
||||
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
|
||||
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
||||
psm.commitTransaction()
|
||||
fiber.commitTransaction()
|
||||
processIORequest(ioRequest)
|
||||
decrementLiveFibers()
|
||||
}
|
||||
psm.actionOnEnd = {
|
||||
fiber.actionOnEnd = {
|
||||
try {
|
||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
mutex.locked {
|
||||
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||
totalFinishedFlows.inc()
|
||||
unfinishedFibers.countDown()
|
||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||
notifyChangeObservers(fiber, AddOrRemove.REMOVE)
|
||||
}
|
||||
endAllFiberSessions(psm)
|
||||
endAllFiberSessions(fiber)
|
||||
} finally {
|
||||
decrementLiveFibers()
|
||||
}
|
||||
@ -340,16 +348,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
mutex.locked {
|
||||
totalStartedFlows.inc()
|
||||
unfinishedFibers.countUp()
|
||||
notifyChangeObservers(psm, AddOrRemove.ADD)
|
||||
notifyChangeObservers(fiber, AddOrRemove.ADD)
|
||||
}
|
||||
}
|
||||
|
||||
private fun endAllFiberSessions(psm: FlowStateMachineImpl<*>) {
|
||||
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>) {
|
||||
openSessions.values.removeIf { session ->
|
||||
if (session.psm == psm) {
|
||||
if (session.fiber == fiber) {
|
||||
val initiatedState = session.state as? FlowSessionState.Initiated
|
||||
if (initiatedState != null) {
|
||||
sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), psm)
|
||||
sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), fiber)
|
||||
recentlyClosedSessions[session.ourSessionId] = initiatedState.peerParty
|
||||
}
|
||||
true
|
||||
@ -405,10 +413,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
return fiber
|
||||
}
|
||||
|
||||
private fun updateCheckpoint(psm: FlowStateMachineImpl<*>) {
|
||||
check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
|
||||
val newCheckpoint = Checkpoint(serializeFiber(psm))
|
||||
val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
|
||||
private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) {
|
||||
check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
|
||||
val newCheckpoint = Checkpoint(serializeFiber(fiber))
|
||||
val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) }
|
||||
if (previousCheckpoint != null) {
|
||||
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
||||
}
|
||||
@ -416,13 +424,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
checkpointingMeter.mark()
|
||||
}
|
||||
|
||||
private fun resumeFiber(psm: FlowStateMachineImpl<*>) {
|
||||
private fun resumeFiber(fiber: FlowStateMachineImpl<*>) {
|
||||
// Avoid race condition when setting stopping to true and then checking liveFibers
|
||||
incrementLiveFibers()
|
||||
if (!stopping) executor.executeASAP {
|
||||
psm.resume(scheduler)
|
||||
fiber.resume(scheduler)
|
||||
} else {
|
||||
psm.logger.debug("Not resuming as SMM is stopping.")
|
||||
fiber.logger.debug("Not resuming as SMM is stopping.")
|
||||
decrementLiveFibers()
|
||||
}
|
||||
}
|
||||
@ -432,51 +440,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
if (ioRequest.message is SessionInit) {
|
||||
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
|
||||
}
|
||||
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.psm)
|
||||
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber)
|
||||
if (ioRequest !is ReceiveRequest<*>) {
|
||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||
resumeFiber(ioRequest.session.psm)
|
||||
resumeFiber(ioRequest.session.fiber)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: FlowStateMachineImpl<*>?) {
|
||||
private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null) {
|
||||
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
|
||||
?: throw IllegalArgumentException("Don't know about party $party")
|
||||
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
|
||||
val logger = psm?.logger ?: logger
|
||||
val logger = fiber?.logger ?: logger
|
||||
logger.debug { "Sending $message to party $party, address: $address" }
|
||||
serviceHub.networkService.send(sessionTopic, message, address)
|
||||
}
|
||||
|
||||
data class ReceivedSessionMessage<out M : SessionMessage>(val sendingParty: Party, val message: M)
|
||||
|
||||
interface SessionMessage
|
||||
|
||||
interface ExistingSessionMessage : SessionMessage {
|
||||
val recipientSessionId: Long
|
||||
}
|
||||
|
||||
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
|
||||
|
||||
interface SessionInitResponse : ExistingSessionMessage
|
||||
|
||||
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
|
||||
override val recipientSessionId: Long get() = initiatorSessionId
|
||||
}
|
||||
|
||||
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
|
||||
override val recipientSessionId: Long get() = initiatorSessionId
|
||||
}
|
||||
|
||||
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
|
||||
override fun toString(): String {
|
||||
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
|
||||
}
|
||||
}
|
||||
|
||||
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
|
||||
|
||||
/**
|
||||
* [FlowSessionState] describes the session's state.
|
||||
*
|
||||
@ -507,7 +487,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
@Volatile var waitingForResponse: Boolean = false
|
||||
) {
|
||||
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
|
||||
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
|
||||
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -415,10 +415,10 @@ class TwoPartyTradeFlowTests {
|
||||
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
|
||||
val buyerFuture = bobNode.initiateSingleShotFlow(Seller::class) { otherParty ->
|
||||
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
|
||||
}.map { it.fsm }
|
||||
}.map { it.stateMachine }
|
||||
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
|
||||
val sellerResultFuture = aliceNode.services.startFlow(seller).resultFuture
|
||||
return RunResult(buyerFuture, sellerResultFuture, seller.fsm.id)
|
||||
return RunResult(buyerFuture, sellerResultFuture, seller.stateMachine.id)
|
||||
}
|
||||
|
||||
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
|
||||
|
@ -7,8 +7,8 @@ import net.corda.core.contracts.DOLLARS
|
||||
import net.corda.core.contracts.issuedBy
|
||||
import net.corda.core.crypto.Party
|
||||
import net.corda.core.crypto.generateKeyPair
|
||||
import net.corda.core.flows.FlowException
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSessionException
|
||||
import net.corda.core.getOrThrow
|
||||
import net.corda.core.random63BitValue
|
||||
import net.corda.core.serialization.OpaqueBytes
|
||||
@ -17,7 +17,6 @@ import net.corda.flows.CashCommand
|
||||
import net.corda.flows.CashFlow
|
||||
import net.corda.flows.NotaryFlow
|
||||
import net.corda.node.services.persistence.checkpoints
|
||||
import net.corda.node.services.statemachine.StateMachineManager.*
|
||||
import net.corda.node.utilities.databaseTransaction
|
||||
import net.corda.testing.expect
|
||||
import net.corda.testing.expectEvents
|
||||
@ -215,15 +214,15 @@ class StateMachineManagerTests {
|
||||
|
||||
assertSessionTransfers(node2,
|
||||
node1 sent sessionInit(SendFlow::class, payload) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node1 sent sessionEnd() to node2
|
||||
node2 sent sessionConfirm to node1,
|
||||
node1 sent sessionEnd to node2
|
||||
//There's no session end from the other flows as they're manually suspended
|
||||
)
|
||||
|
||||
assertSessionTransfers(node3,
|
||||
node1 sent sessionInit(SendFlow::class, payload) to node3,
|
||||
node3 sent sessionConfirm() to node1,
|
||||
node1 sent sessionEnd() to node3
|
||||
node3 sent sessionConfirm to node1,
|
||||
node1 sent sessionEnd to node3
|
||||
//There's no session end from the other flows as they're manually suspended
|
||||
)
|
||||
|
||||
@ -248,16 +247,16 @@ class StateMachineManagerTests {
|
||||
|
||||
assertSessionTransfers(node2,
|
||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node2 sent sessionConfirm to node1,
|
||||
node2 sent sessionData(node2Payload) to node1,
|
||||
node2 sent sessionEnd() to node1
|
||||
node2 sent sessionEnd to node1
|
||||
)
|
||||
|
||||
assertSessionTransfers(node3,
|
||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
|
||||
node3 sent sessionConfirm() to node1,
|
||||
node3 sent sessionConfirm to node1,
|
||||
node3 sent sessionData(node3Payload) to node1,
|
||||
node3 sent sessionEnd() to node1
|
||||
node3 sent sessionEnd to node1
|
||||
)
|
||||
}
|
||||
|
||||
@ -269,11 +268,11 @@ class StateMachineManagerTests {
|
||||
|
||||
assertSessionTransfers(
|
||||
node1 sent sessionInit(PingPongFlow::class, 10L) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node2 sent sessionConfirm to node1,
|
||||
node2 sent sessionData(20L) to node1,
|
||||
node1 sent sessionData(11L) to node2,
|
||||
node2 sent sessionData(21L) to node1,
|
||||
node1 sent sessionEnd() to node2
|
||||
node1 sent sessionEnd to node2
|
||||
)
|
||||
}
|
||||
|
||||
@ -333,11 +332,11 @@ class StateMachineManagerTests {
|
||||
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
|
||||
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture
|
||||
net.runNetwork()
|
||||
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowSessionException::class.java)
|
||||
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java)
|
||||
assertSessionTransfers(
|
||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
||||
node2 sent sessionConfirm() to node1,
|
||||
node2 sent sessionEnd() to node1
|
||||
node2 sent sessionConfirm to node1,
|
||||
node2 sent sessionEnd to node1
|
||||
)
|
||||
}
|
||||
|
||||
@ -358,11 +357,11 @@ class StateMachineManagerTests {
|
||||
|
||||
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
|
||||
|
||||
private fun sessionConfirm() = SessionConfirm(0, 0)
|
||||
private val sessionConfirm = SessionConfirm(0, 0)
|
||||
|
||||
private fun sessionData(payload: Any) = SessionData(0, payload)
|
||||
|
||||
private fun sessionEnd() = SessionEnd(0)
|
||||
private val sessionEnd = SessionEnd(0)
|
||||
|
||||
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
|
||||
assertThat(sessionTransfers).containsExactly(*expected)
|
||||
@ -462,5 +461,4 @@ class StateMachineManagerTests {
|
||||
private object ExceptionFlow : FlowLogic<Nothing>() {
|
||||
override fun call(): Nothing = throw Exception()
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val acceptorTx = node2.initiateSingleShotFlow(Instigator::class) { Acceptor(it) }.flatMap {
|
||||
(it.fsm as FlowStateMachine<SignedTransaction>).resultFuture
|
||||
(it.stateMachine as FlowStateMachine<SignedTransaction>).resultFuture
|
||||
}
|
||||
|
||||
showProgressFor(listOf(node1, node2))
|
||||
|
@ -53,7 +53,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val buyerFuture = buyer.initiateSingleShotFlow(Seller::class) {
|
||||
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
|
||||
}.flatMap { (it.fsm as FlowStateMachine<SignedTransaction>).resultFuture }
|
||||
}.flatMap { (it.stateMachine as FlowStateMachine<SignedTransaction>).resultFuture }
|
||||
|
||||
val sellerKey = seller.services.legalIdentityKey
|
||||
val sellerFlow = Seller(
|
||||
|
@ -17,7 +17,9 @@ import net.corda.core.then
|
||||
import net.corda.core.utilities.ProgressTracker
|
||||
import net.corda.netmap.VisualiserViewModel.Style
|
||||
import net.corda.node.services.network.NetworkMapService
|
||||
import net.corda.node.services.statemachine.StateMachineManager
|
||||
import net.corda.node.services.statemachine.SessionConfirm
|
||||
import net.corda.node.services.statemachine.SessionEnd
|
||||
import net.corda.node.services.statemachine.SessionInit
|
||||
import net.corda.simulation.IRSSimulation
|
||||
import net.corda.simulation.Simulation
|
||||
import net.corda.testing.node.InMemoryMessagingNetwork
|
||||
@ -349,13 +351,12 @@ class NetworkMapVisualiser : Application() {
|
||||
// Network map push acknowledgements are boring.
|
||||
if (NetworkMapService.PUSH_ACK_FLOW_TOPIC in transfer.message.topicSession.topic) return false
|
||||
val message = transfer.message.data.deserialize<Any>()
|
||||
val messageClassType = message.javaClass.name
|
||||
when (messageClassType) {
|
||||
StateMachineManager.SessionEnd::class.java.name -> return false
|
||||
StateMachineManager.SessionConfirm::class.java.name -> return false
|
||||
StateMachineManager.SessionInit::class.java.name -> if ((message as StateMachineManager.SessionInit).firstPayload == null) return false
|
||||
return when (message) {
|
||||
is SessionEnd -> false
|
||||
is SessionConfirm -> false
|
||||
is SessionInit -> message.firstPayload != null
|
||||
else -> true
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user