From e589031d4bb7a29b6d264b8e5829afcf03bb0be0 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Wed, 11 Jan 2017 10:21:54 +0000 Subject: [PATCH] Some clean up of the flow code --- .../kotlin/net/corda/core/flows/FlowLogic.kt | 17 +- .../net/corda/core/flows/FlowStateMachine.kt | 8 +- .../main/kotlin/net/corda/flows/CashFlow.kt | 6 +- .../kotlin/net/corda/flows/IssuerFlowTest.kt | 2 +- ...{ProtocolIORequest.kt => FlowIORequest.kt} | 1 - .../statemachine/FlowStateMachineImpl.kt | 64 +++---- .../services/statemachine/SessionMessage.kt | 43 +++++ .../statemachine/StateMachineManager.kt | 159 ++++++++---------- .../messaging/TwoPartyTradeProtocolTests.kt | 4 +- .../statemachine/StateMachineManagerTests.kt | 34 ++-- .../net/corda/simulation/IRSSimulation.kt | 2 +- .../net/corda/simulation/TradeSimulation.kt | 2 +- .../net/corda/netmap/NetworkMapVisualiser.kt | 15 +- 13 files changed, 188 insertions(+), 169 deletions(-) rename node/src/main/kotlin/net/corda/node/services/statemachine/{ProtocolIORequest.kt => FlowIORequest.kt} (95%) create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index bb759404f7..90bc4499b6 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -27,18 +27,18 @@ import rx.Observable */ abstract class FlowLogic { - /** Reference to the [Fiber] instance that is the top level controller for the entire flow. */ - lateinit var fsm: FlowStateMachine<*> + /** Reference to the [FlowStateMachine] instance that is the top level controller for the entire flow. */ + lateinit var stateMachine: FlowStateMachine<*> /** This is where you should log things to. */ - val logger: Logger get() = fsm.logger + val logger: Logger get() = stateMachine.logger /** * Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is * only available once the flow has started, which means it cannnot be accessed in the constructor. Either * access this lazily or from inside [call]. */ - val serviceHub: ServiceHub get() = fsm.serviceHub + val serviceHub: ServiceHub get() = stateMachine.serviceHub private var sessionFlow: FlowLogic<*> = this @@ -56,19 +56,19 @@ abstract class FlowLogic { @Suspendable fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { - return fsm.sendAndReceive(otherParty, payload, receiveType, sessionFlow) + return stateMachine.sendAndReceive(receiveType, otherParty, payload, sessionFlow) } inline fun receive(otherParty: Party): UntrustworthyData = receive(T::class.java, otherParty) @Suspendable fun receive(receiveType: Class, otherParty: Party): UntrustworthyData { - return fsm.receive(otherParty, receiveType, sessionFlow) + return stateMachine.receive(receiveType, otherParty, sessionFlow) } @Suspendable fun send(otherParty: Party, payload: Any) { - fsm.send(otherParty, payload, sessionFlow) + stateMachine.send(otherParty, payload, sessionFlow) } /** @@ -82,7 +82,7 @@ abstract class FlowLogic { // TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way @Suspendable fun subFlow(subLogic: FlowLogic, shareParentSessions: Boolean = false): R { - subLogic.fsm = fsm + subLogic.stateMachine = stateMachine maybeWireUpProgressTracking(subLogic) if (shareParentSessions) { subLogic.sessionFlow = this @@ -127,5 +127,4 @@ abstract class FlowLogic { Pair(it.currentStep.toString(), it.changes.map { it.toString() }) } } - } diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt index f8e6d8c9b4..f6b691c5c5 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowStateMachine.kt @@ -28,13 +28,13 @@ data class StateMachineRunId private constructor(val uuid: UUID) { */ interface FlowStateMachine { @Suspendable - fun sendAndReceive(otherParty: Party, + fun sendAndReceive(receiveType: Class, + otherParty: Party, payload: Any, - receiveType: Class, sessionFlow: FlowLogic<*>): UntrustworthyData @Suspendable - fun receive(otherParty: Party, receiveType: Class, sessionFlow: FlowLogic<*>): UntrustworthyData + fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData @Suspendable fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) @@ -48,4 +48,4 @@ interface FlowStateMachine { val resultFuture: ListenableFuture } -class FlowSessionException(message: String) : Exception(message) +class FlowException(message: String) : RuntimeException(message) diff --git a/finance/src/main/kotlin/net/corda/flows/CashFlow.kt b/finance/src/main/kotlin/net/corda/flows/CashFlow.kt index d3683b9ff6..1b30781814 100644 --- a/finance/src/main/kotlin/net/corda/flows/CashFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/CashFlow.kt @@ -59,7 +59,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT val flow = FinalityFlow(tx, setOf(req.recipient)) subFlow(flow) return CashFlowResult.Success( - fsm.id, + stateMachine.id, tx, "Cash payment transaction generated" ) @@ -95,7 +95,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT val tx = builder.toSignedTransaction(checkSufficientSignatures = false) subFlow(FinalityFlow(tx, participants)) return CashFlowResult.Success( - fsm.id, + stateMachine.id, tx, "Cash destruction transaction generated" ) @@ -116,7 +116,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT // Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it subFlow(BroadcastTransactionFlow(tx, setOf(req.recipient))) return CashFlowResult.Success( - fsm.id, + stateMachine.id, tx, "Cash issuance completed" ) diff --git a/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt b/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt index 5ee35635cc..92efac55b3 100644 --- a/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt +++ b/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt @@ -54,7 +54,7 @@ class IssuerFlowTest { private fun runIssuerAndIssueRequester(amount: Amount, issueToPartyAndRef: PartyAndReference) : RunResult { val issuerFuture = bankOfCordaNode.initiateSingleShotFlow(IssuerFlow.IssuanceRequester::class) { otherParty -> IssuerFlow.Issuer(issueToPartyAndRef.party) - }.map { it.fsm } + }.map { it.stateMachine } val issueRequest = IssuanceRequester(amount, issueToPartyAndRef.party, issueToPartyAndRef.reference, bankOfCordaNode.info.legalIdentity) val issueRequestResultFuture = bankClientNode.services.startFlow(issueRequest).resultFuture diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ProtocolIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt similarity index 95% rename from node/src/main/kotlin/net/corda/node/services/statemachine/ProtocolIORequest.kt rename to node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt index 42bd7c579d..10a1c72e16 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ProtocolIORequest.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt @@ -1,7 +1,6 @@ package net.corda.node.services.statemachine import net.corda.node.services.statemachine.StateMachineManager.FlowSession -import net.corda.node.services.statemachine.StateMachineManager.SessionMessage // TODO revisit when Kotlin 1.1 is released and data classes can extend other classes interface FlowIORequest { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index e2487cfcc1..9b71110da3 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -7,15 +7,16 @@ import co.paralleluniverse.strands.Strand import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import net.corda.core.crypto.Party +import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSessionException import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.StateMachineRunId import net.corda.core.random63BitValue import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.trace import net.corda.node.services.api.ServiceHubInternal -import net.corda.node.services.statemachine.StateMachineManager.* +import net.corda.node.services.statemachine.StateMachineManager.FlowSession +import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState import net.corda.node.utilities.StrandLocalTransactionManager import net.corda.node.utilities.createDatabaseTransaction import net.corda.node.utilities.databaseTransaction @@ -31,7 +32,6 @@ import java.util.concurrent.ExecutionException class FlowStateMachineImpl(override val id: StateMachineRunId, val logic: FlowLogic, scheduler: FiberScheduler) : Fiber("flow", scheduler), FlowStateMachine { - companion object { // Used to work around a small limitation in Quasar. private val QUASAR_UNBLOCKER = run { @@ -76,7 +76,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, internal val openSessions = HashMap, Party>, FlowSession>() init { - logic.fsm = this + logic.stateMachine = this name = id.toString() } @@ -120,9 +120,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - override fun sendAndReceive(otherParty: Party, + override fun sendAndReceive(receiveType: Class, + otherParty: Party, payload: Any, - receiveType: Class, sessionFlow: FlowLogic<*>): UntrustworthyData { val (session, new) = getSession(otherParty, sessionFlow, payload) val receivedSessionData = if (new) { @@ -132,16 +132,15 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val sendSessionData = createSessionData(session, payload) sendAndReceiveInternal(session, sendSessionData) } - return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) + return receivedSessionData.checkPayloadIs(receiveType) } @Suspendable - override fun receive(otherParty: Party, - receiveType: Class, + override fun receive(receiveType: Class, + otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData { val session = getSession(otherParty, sessionFlow, null).first - val receivedSessionData = receiveInternal(session) - return UntrustworthyData(receiveType.cast(receivedSessionData.payload)) + return receiveInternal(session).checkPayloadIs(receiveType) } @Suspendable @@ -156,8 +155,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, private fun createSessionData(session: FlowSession, payload: Any): SessionData { val sessionState = session.state val peerSessionId = when (sessionState) { - is StateMachineManager.FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session") - is StateMachineManager.FlowSessionState.Initiated -> sessionState.peerSessionId + is FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session") + is FlowSessionState.Initiated -> sessionState.peerSessionId } return SessionData(peerSessionId, payload) } @@ -167,16 +166,13 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, suspend(SendOnly(session, message)) } - @Suspendable - private inline fun receiveInternal(session: FlowSession): M { - return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message + private inline fun receiveInternal(session: FlowSession): ReceivedSessionMessage { + return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)) } - private inline fun sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M { - return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message - } - - private inline fun sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage { + private inline fun sendAndReceiveInternal( + session: FlowSession, + message: SessionMessage): ReceivedSessionMessage { return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)) } @@ -203,20 +199,21 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, openSessions[Pair(sessionFlow, otherParty)] = session val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload) - val (peerParty, sessionInitResponse) = sendAndReceiveInternalWithParty(session, sessionInit) + val (peerParty, sessionInitResponse) = sendAndReceiveInternal(session, sessionInit) if (sessionInitResponse is SessionConfirm) { require(session.state is FlowSessionState.Initiating) session.state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId) return session } else { sessionInitResponse as SessionReject - throw FlowSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}") + throw FlowException("Party $otherParty rejected session request: ${sessionInitResponse.errorMessage}") } } @Suspendable - private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): ReceivedSessionMessage { - fun getReceivedMessage(): ReceivedSessionMessage? = receiveRequest.session.receivedMessages.poll() + private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): ReceivedSessionMessage { + val session = receiveRequest.session + fun getReceivedMessage(): ReceivedSessionMessage? = session.receivedMessages.poll() val polledMessage = getReceivedMessage() val receivedMessage = if (polledMessage != null) { @@ -228,17 +225,21 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } else { // Suspend while we wait for a receive suspend(receiveRequest) - getReceivedMessage() - ?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest") + getReceivedMessage() ?: + throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " + + "got nothing: $receiveRequest") } if (receivedMessage.message is SessionEnd) { - openSessions.values.remove(receiveRequest.session) - throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest") + openSessions.values.remove(session) + throw FlowException("Party ${session.state.sendToParty} has ended their flow but we were expecting to " + + "receive ${receiveRequest.receiveType.simpleName} from them") } else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) { - return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message)) + @Suppress("UNCHECKED_CAST") + return receivedMessage as ReceivedSessionMessage } else { - throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest") + throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " + + "${receivedMessage.message}: $receiveRequest") } } @@ -292,5 +293,4 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, logger.error("Error during resume", t) } } - } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt new file mode 100644 index 0000000000..6602e55add --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt @@ -0,0 +1,43 @@ +package net.corda.node.services.statemachine + +import net.corda.core.abbreviate +import net.corda.core.crypto.Party +import net.corda.core.flows.FlowException +import net.corda.core.utilities.UntrustworthyData + +interface SessionMessage + +interface ExistingSessionMessage : SessionMessage { + val recipientSessionId: Long +} + +data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage + +interface SessionInitResponse : ExistingSessionMessage + +data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse { + override val recipientSessionId: Long get() = initiatorSessionId +} + +data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse { + override val recipientSessionId: Long get() = initiatorSessionId +} + +data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage { + override fun toString(): String { + return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})" + } +} + +data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage + +data class ReceivedSessionMessage(val sender: Party, val message: M) + +fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { + if (type.isInstance(message.payload)) { + return UntrustworthyData(type.cast(message.payload)) + } else { + throw FlowException("We were expecting a ${type.name} from $sender but we instead got a " + + "${message.payload.javaClass.name} (${message.payload})") + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 44e37250fe..cb0a0bef6e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -9,15 +9,19 @@ import com.esotericsoftware.kryo.Kryo import com.google.common.annotations.VisibleForTesting import com.google.common.util.concurrent.ListenableFuture import kotlinx.support.jdk8.collections.removeIf -import net.corda.core.* +import net.corda.core.ThreadBox +import net.corda.core.bufferUntilSubscribed import net.corda.core.crypto.Party import net.corda.core.crypto.commonName import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.StateMachineRunId +import net.corda.core.messaging.ReceivedMessage import net.corda.core.messaging.TopicSession import net.corda.core.messaging.send +import net.corda.core.random63BitValue import net.corda.core.serialization.* +import net.corda.core.then import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor @@ -89,8 +93,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val stateMachines = LinkedHashMap, Checkpoint>() val changesPublisher = PublishSubject.create() - fun notifyChangeObservers(psm: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) { - changesPublisher.bufferUntilDatabaseCommit().onNext(Change(psm.logic, addOrRemove, psm.id)) + fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) { + changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id)) } }) @@ -125,7 +129,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, stateMachines.keys .map { it.logic } .filterIsInstance(flowClass) - .map { it to (it.fsm as FlowStateMachineImpl).resultFuture } + .map { it to (it.stateMachine as FlowStateMachineImpl).resultFuture } } } @@ -207,17 +211,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } serviceHub.networkService.addMessageHandler(sessionTopic) { message, reg -> executor.checkOnThread() - val sessionMessage = message.data.deserialize() - // TODO Look up the party with the full X.500 name instead of just the legal name - val otherParty = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity - if (otherParty != null) { - when (sessionMessage) { - is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, otherParty) - is SessionInit -> onSessionInit(sessionMessage, otherParty) - } - } else { - logger.error("Unknown peer ${message.peer} in $sessionMessage") - } + onSessionMessage(message) } } @@ -230,26 +224,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun onExistingSessionMessage(message: ExistingSessionMessage, otherParty: Party) { + private fun onSessionMessage(message: ReceivedMessage) { + val sessionMessage = message.data.deserialize() + // TODO Look up the party with the full X.500 name instead of just the legal name + val sender = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity + if (sender != null) { + when (sessionMessage) { + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) + is SessionInit -> onSessionInit(sessionMessage, sender) + } + } else { + logger.error("Unknown peer ${message.peer} in $sessionMessage") + } + } + + private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { val session = openSessions[message.recipientSessionId] if (session != null) { - session.psm.logger.trace { "Received $message on $session" } + session.fiber.logger.trace { "Received $message on $session" } if (message is SessionEnd) { openSessions.remove(message.recipientSessionId) } - session.receivedMessages += ReceivedSessionMessage(otherParty, message) + session.receivedMessages += ReceivedSessionMessage(sender, message) if (session.waitingForResponse) { // We only want to resume once, so immediately reset the flag. session.waitingForResponse = false - updateCheckpoint(session.psm) - resumeFiber(session.psm) + updateCheckpoint(session.fiber) + resumeFiber(session.fiber) } } else { val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) if (peerParty != null) { if (message is SessionConfirm) { logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" } - sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId), null) + sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId)) } else { logger.trace { "Ignoring session end message for already closed session: $message" } } @@ -259,32 +267,32 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun onSessionInit(sessionInit: SessionInit, otherParty: Party) { - logger.trace { "Received $sessionInit $otherParty" } + private fun onSessionInit(sessionInit: SessionInit, sender: Party) { + logger.trace { "Received $sessionInit $sender" } val otherPartySessionId = sessionInit.initiatorSessionId try { val markerClass = Class.forName(sessionInit.flowName) val flowFactory = serviceHub.getFlowFactory(markerClass) if (flowFactory != null) { - val flow = flowFactory(otherParty) - val psm = createFiber(flow) - val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId)) + val flow = flowFactory(sender) + val fiber = createFiber(flow) + val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(sender, otherPartySessionId)) if (sessionInit.firstPayload != null) { - session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload)) + session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) } openSessions[session.ourSessionId] = session - psm.openSessions[Pair(flow, otherParty)] = session - updateCheckpoint(psm) - sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm) - psm.logger.debug { "Initiated from $sessionInit on $session" } - startFiber(psm) + fiber.openSessions[Pair(flow, sender)] = session + updateCheckpoint(fiber) + sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), fiber) + fiber.logger.debug { "Initiated from $sessionInit on $session" } + startFiber(fiber) } else { logger.warn("Unknown flow marker class in $sessionInit") - sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null) + sendSessionMessage(sender, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}")) } } catch (e: Exception) { logger.warn("Received invalid $sessionInit", e) - sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null) + sendSessionMessage(sender, SessionReject(otherPartySessionId, "Unable to establish session")) } } @@ -312,27 +320,27 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) } } - private fun initFiber(psm: FlowStateMachineImpl<*>) { - psm.database = database - psm.serviceHub = serviceHub - psm.actionOnSuspend = { ioRequest -> - updateCheckpoint(psm) + private fun initFiber(fiber: FlowStateMachineImpl<*>) { + fiber.database = database + fiber.serviceHub = serviceHub + fiber.actionOnSuspend = { ioRequest -> + updateCheckpoint(fiber) // We commit on the fibers transaction that was copied across ThreadLocals during suspend // This will free up the ThreadLocal so on return the caller can carry on with other transactions - psm.commitTransaction() + fiber.commitTransaction() processIORequest(ioRequest) decrementLiveFibers() } - psm.actionOnEnd = { + fiber.actionOnEnd = { try { - psm.logic.progressTracker?.currentStep = ProgressTracker.DONE + fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE mutex.locked { - stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) } + stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } totalFinishedFlows.inc() unfinishedFibers.countDown() - notifyChangeObservers(psm, AddOrRemove.REMOVE) + notifyChangeObservers(fiber, AddOrRemove.REMOVE) } - endAllFiberSessions(psm) + endAllFiberSessions(fiber) } finally { decrementLiveFibers() } @@ -340,16 +348,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, mutex.locked { totalStartedFlows.inc() unfinishedFibers.countUp() - notifyChangeObservers(psm, AddOrRemove.ADD) + notifyChangeObservers(fiber, AddOrRemove.ADD) } } - private fun endAllFiberSessions(psm: FlowStateMachineImpl<*>) { + private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>) { openSessions.values.removeIf { session -> - if (session.psm == psm) { + if (session.fiber == fiber) { val initiatedState = session.state as? FlowSessionState.Initiated if (initiatedState != null) { - sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), psm) + sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), fiber) recentlyClosedSessions[session.ourSessionId] = initiatedState.peerParty } true @@ -405,10 +413,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, return fiber } - private fun updateCheckpoint(psm: FlowStateMachineImpl<*>) { - check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } - val newCheckpoint = Checkpoint(serializeFiber(psm)) - val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) } + private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) { + check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" } + val newCheckpoint = Checkpoint(serializeFiber(fiber)) + val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) } if (previousCheckpoint != null) { checkpointStorage.removeCheckpoint(previousCheckpoint) } @@ -416,13 +424,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, checkpointingMeter.mark() } - private fun resumeFiber(psm: FlowStateMachineImpl<*>) { + private fun resumeFiber(fiber: FlowStateMachineImpl<*>) { // Avoid race condition when setting stopping to true and then checking liveFibers incrementLiveFibers() if (!stopping) executor.executeASAP { - psm.resume(scheduler) + fiber.resume(scheduler) } else { - psm.logger.debug("Not resuming as SMM is stopping.") + fiber.logger.debug("Not resuming as SMM is stopping.") decrementLiveFibers() } } @@ -432,51 +440,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, if (ioRequest.message is SessionInit) { openSessions[ioRequest.session.ourSessionId] = ioRequest.session } - sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.psm) + sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber) if (ioRequest !is ReceiveRequest<*>) { // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. - resumeFiber(ioRequest.session.psm) + resumeFiber(ioRequest.session.fiber) } } } - private fun sendSessionMessage(party: Party, message: SessionMessage, psm: FlowStateMachineImpl<*>?) { + private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null) { val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about party $party") val address = serviceHub.networkService.getAddressOfParty(partyInfo) - val logger = psm?.logger ?: logger + val logger = fiber?.logger ?: logger logger.debug { "Sending $message to party $party, address: $address" } serviceHub.networkService.send(sessionTopic, message, address) } - data class ReceivedSessionMessage(val sendingParty: Party, val message: M) - - interface SessionMessage - - interface ExistingSessionMessage : SessionMessage { - val recipientSessionId: Long - } - - data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage - - interface SessionInitResponse : ExistingSessionMessage - - data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse { - override val recipientSessionId: Long get() = initiatorSessionId - } - - data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse { - override val recipientSessionId: Long get() = initiatorSessionId - } - - data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage { - override fun toString(): String { - return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})" - } - } - - data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage - /** * [FlowSessionState] describes the session's state. * @@ -507,7 +487,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, @Volatile var waitingForResponse: Boolean = false ) { val receivedMessages = ConcurrentLinkedQueue>() - val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*> + val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> } - } diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt index 8b4e5dbf23..06370a7449 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -415,10 +415,10 @@ class TwoPartyTradeFlowTests { private fun runBuyerAndSeller(assetToSell: StateAndRef): RunResult { val buyerFuture = bobNode.initiateSingleShotFlow(Seller::class) { otherParty -> Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java) - }.map { it.fsm } + }.map { it.stateMachine } val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY) val sellerResultFuture = aliceNode.services.startFlow(seller).resultFuture - return RunResult(buyerFuture, sellerResultFuture, seller.fsm.id) + return RunResult(buyerFuture, sellerResultFuture, seller.stateMachine.id) } private fun LedgerDSL.runWithError( diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt index d66c4f93b9..d6ebf752b1 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt @@ -7,8 +7,8 @@ import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.issuedBy import net.corda.core.crypto.Party import net.corda.core.crypto.generateKeyPair +import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSessionException import net.corda.core.getOrThrow import net.corda.core.random63BitValue import net.corda.core.serialization.OpaqueBytes @@ -17,7 +17,6 @@ import net.corda.flows.CashCommand import net.corda.flows.CashFlow import net.corda.flows.NotaryFlow import net.corda.node.services.persistence.checkpoints -import net.corda.node.services.statemachine.StateMachineManager.* import net.corda.node.utilities.databaseTransaction import net.corda.testing.expect import net.corda.testing.expectEvents @@ -215,15 +214,15 @@ class StateMachineManagerTests { assertSessionTransfers(node2, node1 sent sessionInit(SendFlow::class, payload) to node2, - node2 sent sessionConfirm() to node1, - node1 sent sessionEnd() to node2 + node2 sent sessionConfirm to node1, + node1 sent sessionEnd to node2 //There's no session end from the other flows as they're manually suspended ) assertSessionTransfers(node3, node1 sent sessionInit(SendFlow::class, payload) to node3, - node3 sent sessionConfirm() to node1, - node1 sent sessionEnd() to node3 + node3 sent sessionConfirm to node1, + node1 sent sessionEnd to node3 //There's no session end from the other flows as they're manually suspended ) @@ -248,16 +247,16 @@ class StateMachineManagerTests { assertSessionTransfers(node2, node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, - node2 sent sessionConfirm() to node1, + node2 sent sessionConfirm to node1, node2 sent sessionData(node2Payload) to node1, - node2 sent sessionEnd() to node1 + node2 sent sessionEnd to node1 ) assertSessionTransfers(node3, node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3, - node3 sent sessionConfirm() to node1, + node3 sent sessionConfirm to node1, node3 sent sessionData(node3Payload) to node1, - node3 sent sessionEnd() to node1 + node3 sent sessionEnd to node1 ) } @@ -269,11 +268,11 @@ class StateMachineManagerTests { assertSessionTransfers( node1 sent sessionInit(PingPongFlow::class, 10L) to node2, - node2 sent sessionConfirm() to node1, + node2 sent sessionConfirm to node1, node2 sent sessionData(20L) to node1, node1 sent sessionData(11L) to node2, node2 sent sessionData(21L) to node1, - node1 sent sessionEnd() to node2 + node1 sent sessionEnd to node2 ) } @@ -333,11 +332,11 @@ class StateMachineManagerTests { node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow } val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture net.runNetwork() - assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowSessionException::class.java) + assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java) assertSessionTransfers( node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, - node2 sent sessionConfirm() to node1, - node2 sent sessionEnd() to node1 + node2 sent sessionConfirm to node1, + node2 sent sessionEnd to node1 ) } @@ -358,11 +357,11 @@ class StateMachineManagerTests { private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload) - private fun sessionConfirm() = SessionConfirm(0, 0) + private val sessionConfirm = SessionConfirm(0, 0) private fun sessionData(payload: Any) = SessionData(0, payload) - private fun sessionEnd() = SessionEnd(0) + private val sessionEnd = SessionEnd(0) private fun assertSessionTransfers(vararg expected: SessionTransfer) { assertThat(sessionTransfers).containsExactly(*expected) @@ -462,5 +461,4 @@ class StateMachineManagerTests { private object ExceptionFlow : FlowLogic() { override fun call(): Nothing = throw Exception() } - } diff --git a/samples/irs-demo/src/main/kotlin/net/corda/simulation/IRSSimulation.kt b/samples/irs-demo/src/main/kotlin/net/corda/simulation/IRSSimulation.kt index aaea02868e..4b67d341eb 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/simulation/IRSSimulation.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/simulation/IRSSimulation.kt @@ -120,7 +120,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten @Suppress("UNCHECKED_CAST") val acceptorTx = node2.initiateSingleShotFlow(Instigator::class) { Acceptor(it) }.flatMap { - (it.fsm as FlowStateMachine).resultFuture + (it.stateMachine as FlowStateMachine).resultFuture } showProgressFor(listOf(node1, node2)) diff --git a/samples/irs-demo/src/main/kotlin/net/corda/simulation/TradeSimulation.kt b/samples/irs-demo/src/main/kotlin/net/corda/simulation/TradeSimulation.kt index 35766073ee..036fbfc92b 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/simulation/TradeSimulation.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/simulation/TradeSimulation.kt @@ -53,7 +53,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo @Suppress("UNCHECKED_CAST") val buyerFuture = buyer.initiateSingleShotFlow(Seller::class) { Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java) - }.flatMap { (it.fsm as FlowStateMachine).resultFuture } + }.flatMap { (it.stateMachine as FlowStateMachine).resultFuture } val sellerKey = seller.services.legalIdentityKey val sellerFlow = Seller( diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt index eeef8c7313..0f3feab89d 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt @@ -17,7 +17,9 @@ import net.corda.core.then import net.corda.core.utilities.ProgressTracker import net.corda.netmap.VisualiserViewModel.Style import net.corda.node.services.network.NetworkMapService -import net.corda.node.services.statemachine.StateMachineManager +import net.corda.node.services.statemachine.SessionConfirm +import net.corda.node.services.statemachine.SessionEnd +import net.corda.node.services.statemachine.SessionInit import net.corda.simulation.IRSSimulation import net.corda.simulation.Simulation import net.corda.testing.node.InMemoryMessagingNetwork @@ -349,13 +351,12 @@ class NetworkMapVisualiser : Application() { // Network map push acknowledgements are boring. if (NetworkMapService.PUSH_ACK_FLOW_TOPIC in transfer.message.topicSession.topic) return false val message = transfer.message.data.deserialize() - val messageClassType = message.javaClass.name - when (messageClassType) { - StateMachineManager.SessionEnd::class.java.name -> return false - StateMachineManager.SessionConfirm::class.java.name -> return false - StateMachineManager.SessionInit::class.java.name -> if ((message as StateMachineManager.SessionInit).firstPayload == null) return false + return when (message) { + is SessionEnd -> false + is SessionConfirm -> false + is SessionInit -> message.firstPayload != null + else -> true } - return true } }