Some clean up of the flow code

This commit is contained in:
Shams Asari 2017-01-11 10:21:54 +00:00
parent 95a33168d8
commit e589031d4b
13 changed files with 188 additions and 169 deletions

View File

@ -27,18 +27,18 @@ import rx.Observable
*/
abstract class FlowLogic<out T> {
/** Reference to the [Fiber] instance that is the top level controller for the entire flow. */
lateinit var fsm: FlowStateMachine<*>
/** Reference to the [FlowStateMachine] instance that is the top level controller for the entire flow. */
lateinit var stateMachine: FlowStateMachine<*>
/** This is where you should log things to. */
val logger: Logger get() = fsm.logger
val logger: Logger get() = stateMachine.logger
/**
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
* access this lazily or from inside [call].
*/
val serviceHub: ServiceHub get() = fsm.serviceHub
val serviceHub: ServiceHub get() = stateMachine.serviceHub
private var sessionFlow: FlowLogic<*> = this
@ -56,19 +56,19 @@ abstract class FlowLogic<out T> {
@Suspendable
fun <T : Any> sendAndReceive(receiveType: Class<T>, otherParty: Party, payload: Any): UntrustworthyData<T> {
return fsm.sendAndReceive(otherParty, payload, receiveType, sessionFlow)
return stateMachine.sendAndReceive(receiveType, otherParty, payload, sessionFlow)
}
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(T::class.java, otherParty)
@Suspendable
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party): UntrustworthyData<T> {
return fsm.receive(otherParty, receiveType, sessionFlow)
return stateMachine.receive(receiveType, otherParty, sessionFlow)
}
@Suspendable
fun send(otherParty: Party, payload: Any) {
fsm.send(otherParty, payload, sessionFlow)
stateMachine.send(otherParty, payload, sessionFlow)
}
/**
@ -82,7 +82,7 @@ abstract class FlowLogic<out T> {
// TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way
@Suspendable
fun <R> subFlow(subLogic: FlowLogic<R>, shareParentSessions: Boolean = false): R {
subLogic.fsm = fsm
subLogic.stateMachine = stateMachine
maybeWireUpProgressTracking(subLogic)
if (shareParentSessions) {
subLogic.sessionFlow = this
@ -127,5 +127,4 @@ abstract class FlowLogic<out T> {
Pair(it.currentStep.toString(), it.changes.map { it.toString() })
}
}
}

View File

@ -28,13 +28,13 @@ data class StateMachineRunId private constructor(val uuid: UUID) {
*/
interface FlowStateMachine<R> {
@Suspendable
fun <T : Any> sendAndReceive(otherParty: Party,
fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party,
payload: Any,
receiveType: Class<T>,
sessionFlow: FlowLogic<*>): UntrustworthyData<T>
@Suspendable
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
@Suspendable
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
@ -48,4 +48,4 @@ interface FlowStateMachine<R> {
val resultFuture: ListenableFuture<R>
}
class FlowSessionException(message: String) : Exception(message)
class FlowException(message: String) : RuntimeException(message)

View File

@ -59,7 +59,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
val flow = FinalityFlow(tx, setOf(req.recipient))
subFlow(flow)
return CashFlowResult.Success(
fsm.id,
stateMachine.id,
tx,
"Cash payment transaction generated"
)
@ -95,7 +95,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
subFlow(FinalityFlow(tx, participants))
return CashFlowResult.Success(
fsm.id,
stateMachine.id,
tx,
"Cash destruction transaction generated"
)
@ -116,7 +116,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
subFlow(BroadcastTransactionFlow(tx, setOf(req.recipient)))
return CashFlowResult.Success(
fsm.id,
stateMachine.id,
tx,
"Cash issuance completed"
)

View File

@ -54,7 +54,7 @@ class IssuerFlowTest {
private fun runIssuerAndIssueRequester(amount: Amount<Currency>, issueToPartyAndRef: PartyAndReference) : RunResult {
val issuerFuture = bankOfCordaNode.initiateSingleShotFlow(IssuerFlow.IssuanceRequester::class) {
otherParty -> IssuerFlow.Issuer(issueToPartyAndRef.party)
}.map { it.fsm }
}.map { it.stateMachine }
val issueRequest = IssuanceRequester(amount, issueToPartyAndRef.party, issueToPartyAndRef.reference, bankOfCordaNode.info.legalIdentity)
val issueRequestResultFuture = bankClientNode.services.startFlow(issueRequest).resultFuture

View File

@ -1,7 +1,6 @@
package net.corda.node.services.statemachine
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
import net.corda.node.services.statemachine.StateMachineManager.SessionMessage
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
interface FlowIORequest {

View File

@ -7,15 +7,16 @@ import co.paralleluniverse.strands.Strand
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSessionException
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.random63BitValue
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.trace
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.statemachine.StateMachineManager.*
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
import net.corda.node.utilities.StrandLocalTransactionManager
import net.corda.node.utilities.createDatabaseTransaction
import net.corda.node.utilities.databaseTransaction
@ -31,7 +32,6 @@ import java.util.concurrent.ExecutionException
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val logic: FlowLogic<R>,
scheduler: FiberScheduler) : Fiber<R>("flow", scheduler), FlowStateMachine<R> {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {
@ -76,7 +76,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>()
init {
logic.fsm = this
logic.stateMachine = this
name = id.toString()
}
@ -120,9 +120,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
override fun <T : Any> sendAndReceive(otherParty: Party,
override fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party,
payload: Any,
receiveType: Class<T>,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
val (session, new) = getSession(otherParty, sessionFlow, payload)
val receivedSessionData = if (new) {
@ -132,16 +132,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val sendSessionData = createSessionData(session, payload)
sendAndReceiveInternal<SessionData>(session, sendSessionData)
}
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
return receivedSessionData.checkPayloadIs(receiveType)
}
@Suspendable
override fun <T : Any> receive(otherParty: Party,
receiveType: Class<T>,
override fun <T : Any> receive(receiveType: Class<T>,
otherParty: Party,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
val session = getSession(otherParty, sessionFlow, null).first
val receivedSessionData = receiveInternal<SessionData>(session)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
return receiveInternal<SessionData>(session).checkPayloadIs(receiveType)
}
@Suspendable
@ -156,8 +155,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
private fun createSessionData(session: FlowSession, payload: Any): SessionData {
val sessionState = session.state
val peerSessionId = when (sessionState) {
is StateMachineManager.FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
is StateMachineManager.FlowSessionState.Initiated -> sessionState.peerSessionId
is FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
is FlowSessionState.Initiated -> sessionState.peerSessionId
}
return SessionData(peerSessionId, payload)
}
@ -167,16 +166,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
suspend(SendOnly(session, message))
}
@Suspendable
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message
private inline fun <reified M : ExistingSessionMessage> receiveInternal(session: FlowSession): ReceivedSessionMessage<M> {
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
}
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message
}
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage<M> {
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
session: FlowSession,
message: SessionMessage): ReceivedSessionMessage<M> {
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
}
@ -203,20 +199,21 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
openSessions[Pair(sessionFlow, otherParty)] = session
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
val (peerParty, sessionInitResponse) = sendAndReceiveInternalWithParty<SessionInitResponse>(session, sessionInit)
val (peerParty, sessionInitResponse) = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
if (sessionInitResponse is SessionConfirm) {
require(session.state is FlowSessionState.Initiating)
session.state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId)
return session
} else {
sessionInitResponse as SessionReject
throw FlowSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
throw FlowException("Party $otherParty rejected session request: ${sessionInitResponse.errorMessage}")
}
}
@Suspendable
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
private fun <M : ExistingSessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
val session = receiveRequest.session
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = session.receivedMessages.poll()
val polledMessage = getReceivedMessage()
val receivedMessage = if (polledMessage != null) {
@ -228,17 +225,21 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} else {
// Suspend while we wait for a receive
suspend(receiveRequest)
getReceivedMessage()
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
getReceivedMessage() ?:
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " +
"got nothing: $receiveRequest")
}
if (receivedMessage.message is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
openSessions.values.remove(session)
throw FlowException("Party ${session.state.sendToParty} has ended their flow but we were expecting to " +
"receive ${receiveRequest.receiveType.simpleName} from them")
} else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message))
@Suppress("UNCHECKED_CAST")
return receivedMessage as ReceivedSessionMessage<M>
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " +
"${receivedMessage.message}: $receiveRequest")
}
}
@ -292,5 +293,4 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
logger.error("Error during resume", t)
}
}
}

View File

@ -0,0 +1,43 @@
package net.corda.node.services.statemachine
import net.corda.core.abbreviate
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowException
import net.corda.core.utilities.UntrustworthyData
interface SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
override fun toString(): String {
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
}
}
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
if (type.isInstance(message.payload)) {
return UntrustworthyData(type.cast(message.payload))
} else {
throw FlowException("We were expecting a ${type.name} from $sender but we instead got a " +
"${message.payload.javaClass.name} (${message.payload})")
}
}

View File

@ -9,15 +9,19 @@ import com.esotericsoftware.kryo.Kryo
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import kotlinx.support.jdk8.collections.removeIf
import net.corda.core.*
import net.corda.core.ThreadBox
import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party
import net.corda.core.crypto.commonName
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.ReceivedMessage
import net.corda.core.messaging.TopicSession
import net.corda.core.messaging.send
import net.corda.core.random63BitValue
import net.corda.core.serialization.*
import net.corda.core.then
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor
@ -89,8 +93,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val stateMachines = LinkedHashMap<FlowStateMachineImpl<*>, Checkpoint>()
val changesPublisher = PublishSubject.create<Change>()
fun notifyChangeObservers(psm: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(psm.logic, addOrRemove, psm.id))
fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id))
}
})
@ -125,7 +129,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
stateMachines.keys
.map { it.logic }
.filterIsInstance(flowClass)
.map { it to (it.fsm as FlowStateMachineImpl<T>).resultFuture }
.map { it to (it.stateMachine as FlowStateMachineImpl<T>).resultFuture }
}
}
@ -207,17 +211,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
serviceHub.networkService.addMessageHandler(sessionTopic) { message, reg ->
executor.checkOnThread()
val sessionMessage = message.data.deserialize<SessionMessage>()
// TODO Look up the party with the full X.500 name instead of just the legal name
val otherParty = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity
if (otherParty != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, otherParty)
is SessionInit -> onSessionInit(sessionMessage, otherParty)
}
} else {
logger.error("Unknown peer ${message.peer} in $sessionMessage")
}
onSessionMessage(message)
}
}
@ -230,26 +224,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
}
private fun onExistingSessionMessage(message: ExistingSessionMessage, otherParty: Party) {
private fun onSessionMessage(message: ReceivedMessage) {
val sessionMessage = message.data.deserialize<SessionMessage>()
// TODO Look up the party with the full X.500 name instead of just the legal name
val sender = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity
if (sender != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
is SessionInit -> onSessionInit(sessionMessage, sender)
}
} else {
logger.error("Unknown peer ${message.peer} in $sessionMessage")
}
}
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
val session = openSessions[message.recipientSessionId]
if (session != null) {
session.psm.logger.trace { "Received $message on $session" }
session.fiber.logger.trace { "Received $message on $session" }
if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += ReceivedSessionMessage(otherParty, message)
session.receivedMessages += ReceivedSessionMessage(sender, message)
if (session.waitingForResponse) {
// We only want to resume once, so immediately reset the flag.
session.waitingForResponse = false
updateCheckpoint(session.psm)
resumeFiber(session.psm)
updateCheckpoint(session.fiber)
resumeFiber(session.fiber)
}
} else {
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (peerParty != null) {
if (message is SessionConfirm) {
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId), null)
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId))
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
}
@ -259,32 +267,32 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
}
private fun onSessionInit(sessionInit: SessionInit, otherParty: Party) {
logger.trace { "Received $sessionInit $otherParty" }
private fun onSessionInit(sessionInit: SessionInit, sender: Party) {
logger.trace { "Received $sessionInit $sender" }
val otherPartySessionId = sessionInit.initiatorSessionId
try {
val markerClass = Class.forName(sessionInit.flowName)
val flowFactory = serviceHub.getFlowFactory(markerClass)
if (flowFactory != null) {
val flow = flowFactory(otherParty)
val psm = createFiber(flow)
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
val flow = flowFactory(sender)
val fiber = createFiber(flow)
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(sender, otherPartySessionId))
if (sessionInit.firstPayload != null) {
session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
}
openSessions[session.ourSessionId] = session
psm.openSessions[Pair(flow, otherParty)] = session
updateCheckpoint(psm)
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
psm.logger.debug { "Initiated from $sessionInit on $session" }
startFiber(psm)
fiber.openSessions[Pair(flow, sender)] = session
updateCheckpoint(fiber)
sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), fiber)
fiber.logger.debug { "Initiated from $sessionInit on $session" }
startFiber(fiber)
} else {
logger.warn("Unknown flow marker class in $sessionInit")
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
sendSessionMessage(sender, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"))
}
} catch (e: Exception) {
logger.warn("Received invalid $sessionInit", e)
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
sendSessionMessage(sender, SessionReject(otherPartySessionId, "Unable to establish session"))
}
}
@ -312,27 +320,27 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) }
}
private fun initFiber(psm: FlowStateMachineImpl<*>) {
psm.database = database
psm.serviceHub = serviceHub
psm.actionOnSuspend = { ioRequest ->
updateCheckpoint(psm)
private fun initFiber(fiber: FlowStateMachineImpl<*>) {
fiber.database = database
fiber.serviceHub = serviceHub
fiber.actionOnSuspend = { ioRequest ->
updateCheckpoint(fiber)
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
psm.commitTransaction()
fiber.commitTransaction()
processIORequest(ioRequest)
decrementLiveFibers()
}
psm.actionOnEnd = {
fiber.actionOnEnd = {
try {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked {
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
totalFinishedFlows.inc()
unfinishedFibers.countDown()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
notifyChangeObservers(fiber, AddOrRemove.REMOVE)
}
endAllFiberSessions(psm)
endAllFiberSessions(fiber)
} finally {
decrementLiveFibers()
}
@ -340,16 +348,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
mutex.locked {
totalStartedFlows.inc()
unfinishedFibers.countUp()
notifyChangeObservers(psm, AddOrRemove.ADD)
notifyChangeObservers(fiber, AddOrRemove.ADD)
}
}
private fun endAllFiberSessions(psm: FlowStateMachineImpl<*>) {
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>) {
openSessions.values.removeIf { session ->
if (session.psm == psm) {
if (session.fiber == fiber) {
val initiatedState = session.state as? FlowSessionState.Initiated
if (initiatedState != null) {
sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), psm)
sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), fiber)
recentlyClosedSessions[session.ourSessionId] = initiatedState.peerParty
}
true
@ -405,10 +413,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
return fiber
}
private fun updateCheckpoint(psm: FlowStateMachineImpl<*>) {
check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
val newCheckpoint = Checkpoint(serializeFiber(psm))
val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) {
check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
val newCheckpoint = Checkpoint(serializeFiber(fiber))
val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) }
if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint)
}
@ -416,13 +424,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
checkpointingMeter.mark()
}
private fun resumeFiber(psm: FlowStateMachineImpl<*>) {
private fun resumeFiber(fiber: FlowStateMachineImpl<*>) {
// Avoid race condition when setting stopping to true and then checking liveFibers
incrementLiveFibers()
if (!stopping) executor.executeASAP {
psm.resume(scheduler)
fiber.resume(scheduler)
} else {
psm.logger.debug("Not resuming as SMM is stopping.")
fiber.logger.debug("Not resuming as SMM is stopping.")
decrementLiveFibers()
}
}
@ -432,51 +440,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (ioRequest.message is SessionInit) {
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
}
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.psm)
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber)
if (ioRequest !is ReceiveRequest<*>) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.psm)
resumeFiber(ioRequest.session.fiber)
}
}
}
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: FlowStateMachineImpl<*>?) {
private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null) {
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
?: throw IllegalArgumentException("Don't know about party $party")
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val logger = psm?.logger ?: logger
val logger = fiber?.logger ?: logger
logger.debug { "Sending $message to party $party, address: $address" }
serviceHub.networkService.send(sessionTopic, message, address)
}
data class ReceivedSessionMessage<out M : SessionMessage>(val sendingParty: Party, val message: M)
interface SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
override fun toString(): String {
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
}
}
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
/**
* [FlowSessionState] describes the session's state.
*
@ -507,7 +487,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
@Volatile var waitingForResponse: Boolean = false
) {
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
}
}

View File

@ -415,10 +415,10 @@ class TwoPartyTradeFlowTests {
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
val buyerFuture = bobNode.initiateSingleShotFlow(Seller::class) { otherParty ->
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
}.map { it.fsm }
}.map { it.stateMachine }
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
val sellerResultFuture = aliceNode.services.startFlow(seller).resultFuture
return RunResult(buyerFuture, sellerResultFuture, seller.fsm.id)
return RunResult(buyerFuture, sellerResultFuture, seller.stateMachine.id)
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(

View File

@ -7,8 +7,8 @@ import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.issuedBy
import net.corda.core.crypto.Party
import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSessionException
import net.corda.core.getOrThrow
import net.corda.core.random63BitValue
import net.corda.core.serialization.OpaqueBytes
@ -17,7 +17,6 @@ import net.corda.flows.CashCommand
import net.corda.flows.CashFlow
import net.corda.flows.NotaryFlow
import net.corda.node.services.persistence.checkpoints
import net.corda.node.services.statemachine.StateMachineManager.*
import net.corda.node.utilities.databaseTransaction
import net.corda.testing.expect
import net.corda.testing.expectEvents
@ -215,15 +214,15 @@ class StateMachineManagerTests {
assertSessionTransfers(node2,
node1 sent sessionInit(SendFlow::class, payload) to node2,
node2 sent sessionConfirm() to node1,
node1 sent sessionEnd() to node2
node2 sent sessionConfirm to node1,
node1 sent sessionEnd to node2
//There's no session end from the other flows as they're manually suspended
)
assertSessionTransfers(node3,
node1 sent sessionInit(SendFlow::class, payload) to node3,
node3 sent sessionConfirm() to node1,
node1 sent sessionEnd() to node3
node3 sent sessionConfirm to node1,
node1 sent sessionEnd to node3
//There's no session end from the other flows as they're manually suspended
)
@ -248,16 +247,16 @@ class StateMachineManagerTests {
assertSessionTransfers(node2,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionConfirm to node1,
node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd() to node1
node2 sent sessionEnd to node1
)
assertSessionTransfers(node3,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
node3 sent sessionConfirm() to node1,
node3 sent sessionConfirm to node1,
node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd() to node1
node3 sent sessionEnd to node1
)
}
@ -269,11 +268,11 @@ class StateMachineManagerTests {
assertSessionTransfers(
node1 sent sessionInit(PingPongFlow::class, 10L) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionConfirm to node1,
node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1,
node1 sent sessionEnd() to node2
node1 sent sessionEnd to node2
)
}
@ -333,11 +332,11 @@ class StateMachineManagerTests {
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowSessionException::class.java)
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java)
assertSessionTransfers(
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
node2 sent sessionConfirm() to node1,
node2 sent sessionEnd() to node1
node2 sent sessionConfirm to node1,
node2 sent sessionEnd to node1
)
}
@ -358,11 +357,11 @@ class StateMachineManagerTests {
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
private fun sessionConfirm() = SessionConfirm(0, 0)
private val sessionConfirm = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload)
private fun sessionEnd() = SessionEnd(0)
private val sessionEnd = SessionEnd(0)
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(sessionTransfers).containsExactly(*expected)
@ -462,5 +461,4 @@ class StateMachineManagerTests {
private object ExceptionFlow : FlowLogic<Nothing>() {
override fun call(): Nothing = throw Exception()
}
}

View File

@ -120,7 +120,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
@Suppress("UNCHECKED_CAST")
val acceptorTx = node2.initiateSingleShotFlow(Instigator::class) { Acceptor(it) }.flatMap {
(it.fsm as FlowStateMachine<SignedTransaction>).resultFuture
(it.stateMachine as FlowStateMachine<SignedTransaction>).resultFuture
}
showProgressFor(listOf(node1, node2))

View File

@ -53,7 +53,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
@Suppress("UNCHECKED_CAST")
val buyerFuture = buyer.initiateSingleShotFlow(Seller::class) {
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
}.flatMap { (it.fsm as FlowStateMachine<SignedTransaction>).resultFuture }
}.flatMap { (it.stateMachine as FlowStateMachine<SignedTransaction>).resultFuture }
val sellerKey = seller.services.legalIdentityKey
val sellerFlow = Seller(

View File

@ -17,7 +17,9 @@ import net.corda.core.then
import net.corda.core.utilities.ProgressTracker
import net.corda.netmap.VisualiserViewModel.Style
import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.node.services.statemachine.SessionConfirm
import net.corda.node.services.statemachine.SessionEnd
import net.corda.node.services.statemachine.SessionInit
import net.corda.simulation.IRSSimulation
import net.corda.simulation.Simulation
import net.corda.testing.node.InMemoryMessagingNetwork
@ -349,13 +351,12 @@ class NetworkMapVisualiser : Application() {
// Network map push acknowledgements are boring.
if (NetworkMapService.PUSH_ACK_FLOW_TOPIC in transfer.message.topicSession.topic) return false
val message = transfer.message.data.deserialize<Any>()
val messageClassType = message.javaClass.name
when (messageClassType) {
StateMachineManager.SessionEnd::class.java.name -> return false
StateMachineManager.SessionConfirm::class.java.name -> return false
StateMachineManager.SessionInit::class.java.name -> if ((message as StateMachineManager.SessionInit).firstPayload == null) return false
return when (message) {
is SessionEnd -> false
is SessionConfirm -> false
is SessionInit -> message.firstPayload != null
else -> true
}
return true
}
}