mirror of
https://github.com/corda/corda.git
synced 2024-12-24 15:16:45 +00:00
Some clean up of the flow code
This commit is contained in:
parent
95a33168d8
commit
e589031d4b
@ -27,18 +27,18 @@ import rx.Observable
|
|||||||
*/
|
*/
|
||||||
abstract class FlowLogic<out T> {
|
abstract class FlowLogic<out T> {
|
||||||
|
|
||||||
/** Reference to the [Fiber] instance that is the top level controller for the entire flow. */
|
/** Reference to the [FlowStateMachine] instance that is the top level controller for the entire flow. */
|
||||||
lateinit var fsm: FlowStateMachine<*>
|
lateinit var stateMachine: FlowStateMachine<*>
|
||||||
|
|
||||||
/** This is where you should log things to. */
|
/** This is where you should log things to. */
|
||||||
val logger: Logger get() = fsm.logger
|
val logger: Logger get() = stateMachine.logger
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
||||||
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
|
* only available once the flow has started, which means it cannnot be accessed in the constructor. Either
|
||||||
* access this lazily or from inside [call].
|
* access this lazily or from inside [call].
|
||||||
*/
|
*/
|
||||||
val serviceHub: ServiceHub get() = fsm.serviceHub
|
val serviceHub: ServiceHub get() = stateMachine.serviceHub
|
||||||
|
|
||||||
private var sessionFlow: FlowLogic<*> = this
|
private var sessionFlow: FlowLogic<*> = this
|
||||||
|
|
||||||
@ -56,19 +56,19 @@ abstract class FlowLogic<out T> {
|
|||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <T : Any> sendAndReceive(receiveType: Class<T>, otherParty: Party, payload: Any): UntrustworthyData<T> {
|
fun <T : Any> sendAndReceive(receiveType: Class<T>, otherParty: Party, payload: Any): UntrustworthyData<T> {
|
||||||
return fsm.sendAndReceive(otherParty, payload, receiveType, sessionFlow)
|
return stateMachine.sendAndReceive(receiveType, otherParty, payload, sessionFlow)
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(T::class.java, otherParty)
|
inline fun <reified T : Any> receive(otherParty: Party): UntrustworthyData<T> = receive(T::class.java, otherParty)
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party): UntrustworthyData<T> {
|
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party): UntrustworthyData<T> {
|
||||||
return fsm.receive(otherParty, receiveType, sessionFlow)
|
return stateMachine.receive(receiveType, otherParty, sessionFlow)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun send(otherParty: Party, payload: Any) {
|
fun send(otherParty: Party, payload: Any) {
|
||||||
fsm.send(otherParty, payload, sessionFlow)
|
stateMachine.send(otherParty, payload, sessionFlow)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -82,7 +82,7 @@ abstract class FlowLogic<out T> {
|
|||||||
// TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way
|
// TODO shareParentSessions is a bit too low-level and perhaps can be expresed in a better way
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <R> subFlow(subLogic: FlowLogic<R>, shareParentSessions: Boolean = false): R {
|
fun <R> subFlow(subLogic: FlowLogic<R>, shareParentSessions: Boolean = false): R {
|
||||||
subLogic.fsm = fsm
|
subLogic.stateMachine = stateMachine
|
||||||
maybeWireUpProgressTracking(subLogic)
|
maybeWireUpProgressTracking(subLogic)
|
||||||
if (shareParentSessions) {
|
if (shareParentSessions) {
|
||||||
subLogic.sessionFlow = this
|
subLogic.sessionFlow = this
|
||||||
@ -127,5 +127,4 @@ abstract class FlowLogic<out T> {
|
|||||||
Pair(it.currentStep.toString(), it.changes.map { it.toString() })
|
Pair(it.currentStep.toString(), it.changes.map { it.toString() })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -28,13 +28,13 @@ data class StateMachineRunId private constructor(val uuid: UUID) {
|
|||||||
*/
|
*/
|
||||||
interface FlowStateMachine<R> {
|
interface FlowStateMachine<R> {
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <T : Any> sendAndReceive(otherParty: Party,
|
fun <T : Any> sendAndReceive(receiveType: Class<T>,
|
||||||
|
otherParty: Party,
|
||||||
payload: Any,
|
payload: Any,
|
||||||
receiveType: Class<T>,
|
|
||||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <T : Any> receive(otherParty: Party, receiveType: Class<T>, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
|
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>)
|
||||||
@ -48,4 +48,4 @@ interface FlowStateMachine<R> {
|
|||||||
val resultFuture: ListenableFuture<R>
|
val resultFuture: ListenableFuture<R>
|
||||||
}
|
}
|
||||||
|
|
||||||
class FlowSessionException(message: String) : Exception(message)
|
class FlowException(message: String) : RuntimeException(message)
|
||||||
|
@ -59,7 +59,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
|
|||||||
val flow = FinalityFlow(tx, setOf(req.recipient))
|
val flow = FinalityFlow(tx, setOf(req.recipient))
|
||||||
subFlow(flow)
|
subFlow(flow)
|
||||||
return CashFlowResult.Success(
|
return CashFlowResult.Success(
|
||||||
fsm.id,
|
stateMachine.id,
|
||||||
tx,
|
tx,
|
||||||
"Cash payment transaction generated"
|
"Cash payment transaction generated"
|
||||||
)
|
)
|
||||||
@ -95,7 +95,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
|
|||||||
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
|
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
|
||||||
subFlow(FinalityFlow(tx, participants))
|
subFlow(FinalityFlow(tx, participants))
|
||||||
return CashFlowResult.Success(
|
return CashFlowResult.Success(
|
||||||
fsm.id,
|
stateMachine.id,
|
||||||
tx,
|
tx,
|
||||||
"Cash destruction transaction generated"
|
"Cash destruction transaction generated"
|
||||||
)
|
)
|
||||||
@ -116,7 +116,7 @@ class CashFlow(val command: CashCommand, override val progressTracker: ProgressT
|
|||||||
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
|
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
|
||||||
subFlow(BroadcastTransactionFlow(tx, setOf(req.recipient)))
|
subFlow(BroadcastTransactionFlow(tx, setOf(req.recipient)))
|
||||||
return CashFlowResult.Success(
|
return CashFlowResult.Success(
|
||||||
fsm.id,
|
stateMachine.id,
|
||||||
tx,
|
tx,
|
||||||
"Cash issuance completed"
|
"Cash issuance completed"
|
||||||
)
|
)
|
||||||
|
@ -54,7 +54,7 @@ class IssuerFlowTest {
|
|||||||
private fun runIssuerAndIssueRequester(amount: Amount<Currency>, issueToPartyAndRef: PartyAndReference) : RunResult {
|
private fun runIssuerAndIssueRequester(amount: Amount<Currency>, issueToPartyAndRef: PartyAndReference) : RunResult {
|
||||||
val issuerFuture = bankOfCordaNode.initiateSingleShotFlow(IssuerFlow.IssuanceRequester::class) {
|
val issuerFuture = bankOfCordaNode.initiateSingleShotFlow(IssuerFlow.IssuanceRequester::class) {
|
||||||
otherParty -> IssuerFlow.Issuer(issueToPartyAndRef.party)
|
otherParty -> IssuerFlow.Issuer(issueToPartyAndRef.party)
|
||||||
}.map { it.fsm }
|
}.map { it.stateMachine }
|
||||||
|
|
||||||
val issueRequest = IssuanceRequester(amount, issueToPartyAndRef.party, issueToPartyAndRef.reference, bankOfCordaNode.info.legalIdentity)
|
val issueRequest = IssuanceRequester(amount, issueToPartyAndRef.party, issueToPartyAndRef.reference, bankOfCordaNode.info.legalIdentity)
|
||||||
val issueRequestResultFuture = bankClientNode.services.startFlow(issueRequest).resultFuture
|
val issueRequestResultFuture = bankClientNode.services.startFlow(issueRequest).resultFuture
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package net.corda.node.services.statemachine
|
package net.corda.node.services.statemachine
|
||||||
|
|
||||||
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
||||||
import net.corda.node.services.statemachine.StateMachineManager.SessionMessage
|
|
||||||
|
|
||||||
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
|
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
|
||||||
interface FlowIORequest {
|
interface FlowIORequest {
|
@ -7,15 +7,16 @@ import co.paralleluniverse.strands.Strand
|
|||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.google.common.util.concurrent.SettableFuture
|
import com.google.common.util.concurrent.SettableFuture
|
||||||
import net.corda.core.crypto.Party
|
import net.corda.core.crypto.Party
|
||||||
|
import net.corda.core.flows.FlowException
|
||||||
import net.corda.core.flows.FlowLogic
|
import net.corda.core.flows.FlowLogic
|
||||||
import net.corda.core.flows.FlowSessionException
|
|
||||||
import net.corda.core.flows.FlowStateMachine
|
import net.corda.core.flows.FlowStateMachine
|
||||||
import net.corda.core.flows.StateMachineRunId
|
import net.corda.core.flows.StateMachineRunId
|
||||||
import net.corda.core.random63BitValue
|
import net.corda.core.random63BitValue
|
||||||
import net.corda.core.utilities.UntrustworthyData
|
import net.corda.core.utilities.UntrustworthyData
|
||||||
import net.corda.core.utilities.trace
|
import net.corda.core.utilities.trace
|
||||||
import net.corda.node.services.api.ServiceHubInternal
|
import net.corda.node.services.api.ServiceHubInternal
|
||||||
import net.corda.node.services.statemachine.StateMachineManager.*
|
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
||||||
|
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
|
||||||
import net.corda.node.utilities.StrandLocalTransactionManager
|
import net.corda.node.utilities.StrandLocalTransactionManager
|
||||||
import net.corda.node.utilities.createDatabaseTransaction
|
import net.corda.node.utilities.createDatabaseTransaction
|
||||||
import net.corda.node.utilities.databaseTransaction
|
import net.corda.node.utilities.databaseTransaction
|
||||||
@ -31,7 +32,6 @@ import java.util.concurrent.ExecutionException
|
|||||||
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||||
val logic: FlowLogic<R>,
|
val logic: FlowLogic<R>,
|
||||||
scheduler: FiberScheduler) : Fiber<R>("flow", scheduler), FlowStateMachine<R> {
|
scheduler: FiberScheduler) : Fiber<R>("flow", scheduler), FlowStateMachine<R> {
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
// Used to work around a small limitation in Quasar.
|
// Used to work around a small limitation in Quasar.
|
||||||
private val QUASAR_UNBLOCKER = run {
|
private val QUASAR_UNBLOCKER = run {
|
||||||
@ -76,7 +76,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>()
|
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>()
|
||||||
|
|
||||||
init {
|
init {
|
||||||
logic.fsm = this
|
logic.stateMachine = this
|
||||||
name = id.toString()
|
name = id.toString()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,9 +120,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun <T : Any> sendAndReceive(otherParty: Party,
|
override fun <T : Any> sendAndReceive(receiveType: Class<T>,
|
||||||
|
otherParty: Party,
|
||||||
payload: Any,
|
payload: Any,
|
||||||
receiveType: Class<T>,
|
|
||||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
||||||
val (session, new) = getSession(otherParty, sessionFlow, payload)
|
val (session, new) = getSession(otherParty, sessionFlow, payload)
|
||||||
val receivedSessionData = if (new) {
|
val receivedSessionData = if (new) {
|
||||||
@ -132,16 +132,15 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
val sendSessionData = createSessionData(session, payload)
|
val sendSessionData = createSessionData(session, payload)
|
||||||
sendAndReceiveInternal<SessionData>(session, sendSessionData)
|
sendAndReceiveInternal<SessionData>(session, sendSessionData)
|
||||||
}
|
}
|
||||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
return receivedSessionData.checkPayloadIs(receiveType)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun <T : Any> receive(otherParty: Party,
|
override fun <T : Any> receive(receiveType: Class<T>,
|
||||||
receiveType: Class<T>,
|
otherParty: Party,
|
||||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
||||||
val session = getSession(otherParty, sessionFlow, null).first
|
val session = getSession(otherParty, sessionFlow, null).first
|
||||||
val receivedSessionData = receiveInternal<SessionData>(session)
|
return receiveInternal<SessionData>(session).checkPayloadIs(receiveType)
|
||||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
@ -156,8 +155,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
private fun createSessionData(session: FlowSession, payload: Any): SessionData {
|
private fun createSessionData(session: FlowSession, payload: Any): SessionData {
|
||||||
val sessionState = session.state
|
val sessionState = session.state
|
||||||
val peerSessionId = when (sessionState) {
|
val peerSessionId = when (sessionState) {
|
||||||
is StateMachineManager.FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
|
is FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
|
||||||
is StateMachineManager.FlowSessionState.Initiated -> sessionState.peerSessionId
|
is FlowSessionState.Initiated -> sessionState.peerSessionId
|
||||||
}
|
}
|
||||||
return SessionData(peerSessionId, payload)
|
return SessionData(peerSessionId, payload)
|
||||||
}
|
}
|
||||||
@ -167,16 +166,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
suspend(SendOnly(session, message))
|
suspend(SendOnly(session, message))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
private inline fun <reified M : ExistingSessionMessage> receiveInternal(session: FlowSession): ReceivedSessionMessage<M> {
|
||||||
private inline fun <reified M : SessionMessage> receiveInternal(session: FlowSession): M {
|
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java))
|
||||||
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)).message
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternal(session: FlowSession, message: SessionMessage): M {
|
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
|
||||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)).message
|
session: FlowSession,
|
||||||
}
|
message: SessionMessage): ReceivedSessionMessage<M> {
|
||||||
|
|
||||||
private inline fun <reified M : SessionMessage> sendAndReceiveInternalWithParty(session: FlowSession, message: SessionMessage): ReceivedSessionMessage<M> {
|
|
||||||
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
|
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,20 +199,21 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||||
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
|
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
|
||||||
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
|
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
|
||||||
val (peerParty, sessionInitResponse) = sendAndReceiveInternalWithParty<SessionInitResponse>(session, sessionInit)
|
val (peerParty, sessionInitResponse) = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
|
||||||
if (sessionInitResponse is SessionConfirm) {
|
if (sessionInitResponse is SessionConfirm) {
|
||||||
require(session.state is FlowSessionState.Initiating)
|
require(session.state is FlowSessionState.Initiating)
|
||||||
session.state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId)
|
session.state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId)
|
||||||
return session
|
return session
|
||||||
} else {
|
} else {
|
||||||
sessionInitResponse as SessionReject
|
sessionInitResponse as SessionReject
|
||||||
throw FlowSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
|
throw FlowException("Party $otherParty rejected session request: ${sessionInitResponse.errorMessage}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private fun <M : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
|
private fun <M : ExistingSessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
|
||||||
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = receiveRequest.session.receivedMessages.poll()
|
val session = receiveRequest.session
|
||||||
|
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = session.receivedMessages.poll()
|
||||||
|
|
||||||
val polledMessage = getReceivedMessage()
|
val polledMessage = getReceivedMessage()
|
||||||
val receivedMessage = if (polledMessage != null) {
|
val receivedMessage = if (polledMessage != null) {
|
||||||
@ -228,17 +225,21 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
} else {
|
} else {
|
||||||
// Suspend while we wait for a receive
|
// Suspend while we wait for a receive
|
||||||
suspend(receiveRequest)
|
suspend(receiveRequest)
|
||||||
getReceivedMessage()
|
getReceivedMessage() ?:
|
||||||
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $receiveRequest")
|
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " +
|
||||||
|
"got nothing: $receiveRequest")
|
||||||
}
|
}
|
||||||
|
|
||||||
if (receivedMessage.message is SessionEnd) {
|
if (receivedMessage.message is SessionEnd) {
|
||||||
openSessions.values.remove(receiveRequest.session)
|
openSessions.values.remove(session)
|
||||||
throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
|
throw FlowException("Party ${session.state.sendToParty} has ended their flow but we were expecting to " +
|
||||||
|
"receive ${receiveRequest.receiveType.simpleName} from them")
|
||||||
} else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
|
} else if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
|
||||||
return ReceivedSessionMessage(receivedMessage.sendingParty, receiveRequest.receiveType.cast(receivedMessage.message))
|
@Suppress("UNCHECKED_CAST")
|
||||||
|
return receivedMessage as ReceivedSessionMessage<M>
|
||||||
} else {
|
} else {
|
||||||
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $receiveRequest")
|
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " +
|
||||||
|
"${receivedMessage.message}: $receiveRequest")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,5 +293,4 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
logger.error("Error during resume", t)
|
logger.error("Error during resume", t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,43 @@
|
|||||||
|
package net.corda.node.services.statemachine
|
||||||
|
|
||||||
|
import net.corda.core.abbreviate
|
||||||
|
import net.corda.core.crypto.Party
|
||||||
|
import net.corda.core.flows.FlowException
|
||||||
|
import net.corda.core.utilities.UntrustworthyData
|
||||||
|
|
||||||
|
interface SessionMessage
|
||||||
|
|
||||||
|
interface ExistingSessionMessage : SessionMessage {
|
||||||
|
val recipientSessionId: Long
|
||||||
|
}
|
||||||
|
|
||||||
|
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
|
||||||
|
|
||||||
|
interface SessionInitResponse : ExistingSessionMessage
|
||||||
|
|
||||||
|
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
|
||||||
|
override val recipientSessionId: Long get() = initiatorSessionId
|
||||||
|
}
|
||||||
|
|
||||||
|
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
|
||||||
|
override val recipientSessionId: Long get() = initiatorSessionId
|
||||||
|
}
|
||||||
|
|
||||||
|
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
|
||||||
|
override fun toString(): String {
|
||||||
|
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
|
||||||
|
|
||||||
|
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
|
||||||
|
|
||||||
|
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
|
||||||
|
if (type.isInstance(message.payload)) {
|
||||||
|
return UntrustworthyData(type.cast(message.payload))
|
||||||
|
} else {
|
||||||
|
throw FlowException("We were expecting a ${type.name} from $sender but we instead got a " +
|
||||||
|
"${message.payload.javaClass.name} (${message.payload})")
|
||||||
|
}
|
||||||
|
}
|
@ -9,15 +9,19 @@ import com.esotericsoftware.kryo.Kryo
|
|||||||
import com.google.common.annotations.VisibleForTesting
|
import com.google.common.annotations.VisibleForTesting
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import kotlinx.support.jdk8.collections.removeIf
|
import kotlinx.support.jdk8.collections.removeIf
|
||||||
import net.corda.core.*
|
import net.corda.core.ThreadBox
|
||||||
|
import net.corda.core.bufferUntilSubscribed
|
||||||
import net.corda.core.crypto.Party
|
import net.corda.core.crypto.Party
|
||||||
import net.corda.core.crypto.commonName
|
import net.corda.core.crypto.commonName
|
||||||
import net.corda.core.flows.FlowLogic
|
import net.corda.core.flows.FlowLogic
|
||||||
import net.corda.core.flows.FlowStateMachine
|
import net.corda.core.flows.FlowStateMachine
|
||||||
import net.corda.core.flows.StateMachineRunId
|
import net.corda.core.flows.StateMachineRunId
|
||||||
|
import net.corda.core.messaging.ReceivedMessage
|
||||||
import net.corda.core.messaging.TopicSession
|
import net.corda.core.messaging.TopicSession
|
||||||
import net.corda.core.messaging.send
|
import net.corda.core.messaging.send
|
||||||
|
import net.corda.core.random63BitValue
|
||||||
import net.corda.core.serialization.*
|
import net.corda.core.serialization.*
|
||||||
|
import net.corda.core.then
|
||||||
import net.corda.core.utilities.ProgressTracker
|
import net.corda.core.utilities.ProgressTracker
|
||||||
import net.corda.core.utilities.debug
|
import net.corda.core.utilities.debug
|
||||||
import net.corda.core.utilities.loggerFor
|
import net.corda.core.utilities.loggerFor
|
||||||
@ -89,8 +93,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
val stateMachines = LinkedHashMap<FlowStateMachineImpl<*>, Checkpoint>()
|
val stateMachines = LinkedHashMap<FlowStateMachineImpl<*>, Checkpoint>()
|
||||||
val changesPublisher = PublishSubject.create<Change>()
|
val changesPublisher = PublishSubject.create<Change>()
|
||||||
|
|
||||||
fun notifyChangeObservers(psm: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
|
fun notifyChangeObservers(fiber: FlowStateMachineImpl<*>, addOrRemove: AddOrRemove) {
|
||||||
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(psm.logic, addOrRemove, psm.id))
|
changesPublisher.bufferUntilDatabaseCommit().onNext(Change(fiber.logic, addOrRemove, fiber.id))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -125,7 +129,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
stateMachines.keys
|
stateMachines.keys
|
||||||
.map { it.logic }
|
.map { it.logic }
|
||||||
.filterIsInstance(flowClass)
|
.filterIsInstance(flowClass)
|
||||||
.map { it to (it.fsm as FlowStateMachineImpl<T>).resultFuture }
|
.map { it to (it.stateMachine as FlowStateMachineImpl<T>).resultFuture }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,17 +211,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
serviceHub.networkService.addMessageHandler(sessionTopic) { message, reg ->
|
serviceHub.networkService.addMessageHandler(sessionTopic) { message, reg ->
|
||||||
executor.checkOnThread()
|
executor.checkOnThread()
|
||||||
val sessionMessage = message.data.deserialize<SessionMessage>()
|
onSessionMessage(message)
|
||||||
// TODO Look up the party with the full X.500 name instead of just the legal name
|
|
||||||
val otherParty = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity
|
|
||||||
if (otherParty != null) {
|
|
||||||
when (sessionMessage) {
|
|
||||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, otherParty)
|
|
||||||
is SessionInit -> onSessionInit(sessionMessage, otherParty)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.error("Unknown peer ${message.peer} in $sessionMessage")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,26 +224,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun onExistingSessionMessage(message: ExistingSessionMessage, otherParty: Party) {
|
private fun onSessionMessage(message: ReceivedMessage) {
|
||||||
|
val sessionMessage = message.data.deserialize<SessionMessage>()
|
||||||
|
// TODO Look up the party with the full X.500 name instead of just the legal name
|
||||||
|
val sender = serviceHub.networkMapCache.getNodeByLegalName(message.peer.commonName)?.legalIdentity
|
||||||
|
if (sender != null) {
|
||||||
|
when (sessionMessage) {
|
||||||
|
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
|
||||||
|
is SessionInit -> onSessionInit(sessionMessage, sender)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.error("Unknown peer ${message.peer} in $sessionMessage")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
|
||||||
val session = openSessions[message.recipientSessionId]
|
val session = openSessions[message.recipientSessionId]
|
||||||
if (session != null) {
|
if (session != null) {
|
||||||
session.psm.logger.trace { "Received $message on $session" }
|
session.fiber.logger.trace { "Received $message on $session" }
|
||||||
if (message is SessionEnd) {
|
if (message is SessionEnd) {
|
||||||
openSessions.remove(message.recipientSessionId)
|
openSessions.remove(message.recipientSessionId)
|
||||||
}
|
}
|
||||||
session.receivedMessages += ReceivedSessionMessage(otherParty, message)
|
session.receivedMessages += ReceivedSessionMessage(sender, message)
|
||||||
if (session.waitingForResponse) {
|
if (session.waitingForResponse) {
|
||||||
// We only want to resume once, so immediately reset the flag.
|
// We only want to resume once, so immediately reset the flag.
|
||||||
session.waitingForResponse = false
|
session.waitingForResponse = false
|
||||||
updateCheckpoint(session.psm)
|
updateCheckpoint(session.fiber)
|
||||||
resumeFiber(session.psm)
|
resumeFiber(session.fiber)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
|
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
|
||||||
if (peerParty != null) {
|
if (peerParty != null) {
|
||||||
if (message is SessionConfirm) {
|
if (message is SessionConfirm) {
|
||||||
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
|
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
|
||||||
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId), null)
|
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId))
|
||||||
} else {
|
} else {
|
||||||
logger.trace { "Ignoring session end message for already closed session: $message" }
|
logger.trace { "Ignoring session end message for already closed session: $message" }
|
||||||
}
|
}
|
||||||
@ -259,32 +267,32 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun onSessionInit(sessionInit: SessionInit, otherParty: Party) {
|
private fun onSessionInit(sessionInit: SessionInit, sender: Party) {
|
||||||
logger.trace { "Received $sessionInit $otherParty" }
|
logger.trace { "Received $sessionInit $sender" }
|
||||||
val otherPartySessionId = sessionInit.initiatorSessionId
|
val otherPartySessionId = sessionInit.initiatorSessionId
|
||||||
try {
|
try {
|
||||||
val markerClass = Class.forName(sessionInit.flowName)
|
val markerClass = Class.forName(sessionInit.flowName)
|
||||||
val flowFactory = serviceHub.getFlowFactory(markerClass)
|
val flowFactory = serviceHub.getFlowFactory(markerClass)
|
||||||
if (flowFactory != null) {
|
if (flowFactory != null) {
|
||||||
val flow = flowFactory(otherParty)
|
val flow = flowFactory(sender)
|
||||||
val psm = createFiber(flow)
|
val fiber = createFiber(flow)
|
||||||
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
|
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(sender, otherPartySessionId))
|
||||||
if (sessionInit.firstPayload != null) {
|
if (sessionInit.firstPayload != null) {
|
||||||
session.receivedMessages += ReceivedSessionMessage(otherParty, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
||||||
}
|
}
|
||||||
openSessions[session.ourSessionId] = session
|
openSessions[session.ourSessionId] = session
|
||||||
psm.openSessions[Pair(flow, otherParty)] = session
|
fiber.openSessions[Pair(flow, sender)] = session
|
||||||
updateCheckpoint(psm)
|
updateCheckpoint(fiber)
|
||||||
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
|
sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), fiber)
|
||||||
psm.logger.debug { "Initiated from $sessionInit on $session" }
|
fiber.logger.debug { "Initiated from $sessionInit on $session" }
|
||||||
startFiber(psm)
|
startFiber(fiber)
|
||||||
} else {
|
} else {
|
||||||
logger.warn("Unknown flow marker class in $sessionInit")
|
logger.warn("Unknown flow marker class in $sessionInit")
|
||||||
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
|
sendSessionMessage(sender, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"))
|
||||||
}
|
}
|
||||||
} catch (e: Exception) {
|
} catch (e: Exception) {
|
||||||
logger.warn("Received invalid $sessionInit", e)
|
logger.warn("Received invalid $sessionInit", e)
|
||||||
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
|
sendSessionMessage(sender, SessionReject(otherPartySessionId, "Unable to establish session"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,27 +320,27 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) }
|
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) }
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun initFiber(psm: FlowStateMachineImpl<*>) {
|
private fun initFiber(fiber: FlowStateMachineImpl<*>) {
|
||||||
psm.database = database
|
fiber.database = database
|
||||||
psm.serviceHub = serviceHub
|
fiber.serviceHub = serviceHub
|
||||||
psm.actionOnSuspend = { ioRequest ->
|
fiber.actionOnSuspend = { ioRequest ->
|
||||||
updateCheckpoint(psm)
|
updateCheckpoint(fiber)
|
||||||
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
|
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
|
||||||
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
||||||
psm.commitTransaction()
|
fiber.commitTransaction()
|
||||||
processIORequest(ioRequest)
|
processIORequest(ioRequest)
|
||||||
decrementLiveFibers()
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
psm.actionOnEnd = {
|
fiber.actionOnEnd = {
|
||||||
try {
|
try {
|
||||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||||
mutex.locked {
|
mutex.locked {
|
||||||
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||||
totalFinishedFlows.inc()
|
totalFinishedFlows.inc()
|
||||||
unfinishedFibers.countDown()
|
unfinishedFibers.countDown()
|
||||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
notifyChangeObservers(fiber, AddOrRemove.REMOVE)
|
||||||
}
|
}
|
||||||
endAllFiberSessions(psm)
|
endAllFiberSessions(fiber)
|
||||||
} finally {
|
} finally {
|
||||||
decrementLiveFibers()
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
@ -340,16 +348,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
mutex.locked {
|
mutex.locked {
|
||||||
totalStartedFlows.inc()
|
totalStartedFlows.inc()
|
||||||
unfinishedFibers.countUp()
|
unfinishedFibers.countUp()
|
||||||
notifyChangeObservers(psm, AddOrRemove.ADD)
|
notifyChangeObservers(fiber, AddOrRemove.ADD)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun endAllFiberSessions(psm: FlowStateMachineImpl<*>) {
|
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>) {
|
||||||
openSessions.values.removeIf { session ->
|
openSessions.values.removeIf { session ->
|
||||||
if (session.psm == psm) {
|
if (session.fiber == fiber) {
|
||||||
val initiatedState = session.state as? FlowSessionState.Initiated
|
val initiatedState = session.state as? FlowSessionState.Initiated
|
||||||
if (initiatedState != null) {
|
if (initiatedState != null) {
|
||||||
sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), psm)
|
sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), fiber)
|
||||||
recentlyClosedSessions[session.ourSessionId] = initiatedState.peerParty
|
recentlyClosedSessions[session.ourSessionId] = initiatedState.peerParty
|
||||||
}
|
}
|
||||||
true
|
true
|
||||||
@ -405,10 +413,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
return fiber
|
return fiber
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun updateCheckpoint(psm: FlowStateMachineImpl<*>) {
|
private fun updateCheckpoint(fiber: FlowStateMachineImpl<*>) {
|
||||||
check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
|
check(fiber.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
|
||||||
val newCheckpoint = Checkpoint(serializeFiber(psm))
|
val newCheckpoint = Checkpoint(serializeFiber(fiber))
|
||||||
val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
|
val previousCheckpoint = mutex.locked { stateMachines.put(fiber, newCheckpoint) }
|
||||||
if (previousCheckpoint != null) {
|
if (previousCheckpoint != null) {
|
||||||
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
||||||
}
|
}
|
||||||
@ -416,13 +424,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
checkpointingMeter.mark()
|
checkpointingMeter.mark()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun resumeFiber(psm: FlowStateMachineImpl<*>) {
|
private fun resumeFiber(fiber: FlowStateMachineImpl<*>) {
|
||||||
// Avoid race condition when setting stopping to true and then checking liveFibers
|
// Avoid race condition when setting stopping to true and then checking liveFibers
|
||||||
incrementLiveFibers()
|
incrementLiveFibers()
|
||||||
if (!stopping) executor.executeASAP {
|
if (!stopping) executor.executeASAP {
|
||||||
psm.resume(scheduler)
|
fiber.resume(scheduler)
|
||||||
} else {
|
} else {
|
||||||
psm.logger.debug("Not resuming as SMM is stopping.")
|
fiber.logger.debug("Not resuming as SMM is stopping.")
|
||||||
decrementLiveFibers()
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -432,51 +440,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
if (ioRequest.message is SessionInit) {
|
if (ioRequest.message is SessionInit) {
|
||||||
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
|
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
|
||||||
}
|
}
|
||||||
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.psm)
|
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber)
|
||||||
if (ioRequest !is ReceiveRequest<*>) {
|
if (ioRequest !is ReceiveRequest<*>) {
|
||||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||||
resumeFiber(ioRequest.session.psm)
|
resumeFiber(ioRequest.session.fiber)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: FlowStateMachineImpl<*>?) {
|
private fun sendSessionMessage(party: Party, message: SessionMessage, fiber: FlowStateMachineImpl<*>? = null) {
|
||||||
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
|
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
|
||||||
?: throw IllegalArgumentException("Don't know about party $party")
|
?: throw IllegalArgumentException("Don't know about party $party")
|
||||||
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
|
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
|
||||||
val logger = psm?.logger ?: logger
|
val logger = fiber?.logger ?: logger
|
||||||
logger.debug { "Sending $message to party $party, address: $address" }
|
logger.debug { "Sending $message to party $party, address: $address" }
|
||||||
serviceHub.networkService.send(sessionTopic, message, address)
|
serviceHub.networkService.send(sessionTopic, message, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
data class ReceivedSessionMessage<out M : SessionMessage>(val sendingParty: Party, val message: M)
|
|
||||||
|
|
||||||
interface SessionMessage
|
|
||||||
|
|
||||||
interface ExistingSessionMessage : SessionMessage {
|
|
||||||
val recipientSessionId: Long
|
|
||||||
}
|
|
||||||
|
|
||||||
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
|
|
||||||
|
|
||||||
interface SessionInitResponse : ExistingSessionMessage
|
|
||||||
|
|
||||||
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
|
|
||||||
override val recipientSessionId: Long get() = initiatorSessionId
|
|
||||||
}
|
|
||||||
|
|
||||||
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
|
|
||||||
override val recipientSessionId: Long get() = initiatorSessionId
|
|
||||||
}
|
|
||||||
|
|
||||||
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
|
|
||||||
override fun toString(): String {
|
|
||||||
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* [FlowSessionState] describes the session's state.
|
* [FlowSessionState] describes the session's state.
|
||||||
*
|
*
|
||||||
@ -507,7 +487,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
@Volatile var waitingForResponse: Boolean = false
|
@Volatile var waitingForResponse: Boolean = false
|
||||||
) {
|
) {
|
||||||
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
|
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
|
||||||
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
|
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -415,10 +415,10 @@ class TwoPartyTradeFlowTests {
|
|||||||
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
|
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
|
||||||
val buyerFuture = bobNode.initiateSingleShotFlow(Seller::class) { otherParty ->
|
val buyerFuture = bobNode.initiateSingleShotFlow(Seller::class) { otherParty ->
|
||||||
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
|
Buyer(otherParty, notaryNode.info.notaryIdentity, 1000.DOLLARS, CommercialPaper.State::class.java)
|
||||||
}.map { it.fsm }
|
}.map { it.stateMachine }
|
||||||
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
|
val seller = Seller(bobNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
|
||||||
val sellerResultFuture = aliceNode.services.startFlow(seller).resultFuture
|
val sellerResultFuture = aliceNode.services.startFlow(seller).resultFuture
|
||||||
return RunResult(buyerFuture, sellerResultFuture, seller.fsm.id)
|
return RunResult(buyerFuture, sellerResultFuture, seller.stateMachine.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
|
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
|
||||||
|
@ -7,8 +7,8 @@ import net.corda.core.contracts.DOLLARS
|
|||||||
import net.corda.core.contracts.issuedBy
|
import net.corda.core.contracts.issuedBy
|
||||||
import net.corda.core.crypto.Party
|
import net.corda.core.crypto.Party
|
||||||
import net.corda.core.crypto.generateKeyPair
|
import net.corda.core.crypto.generateKeyPair
|
||||||
|
import net.corda.core.flows.FlowException
|
||||||
import net.corda.core.flows.FlowLogic
|
import net.corda.core.flows.FlowLogic
|
||||||
import net.corda.core.flows.FlowSessionException
|
|
||||||
import net.corda.core.getOrThrow
|
import net.corda.core.getOrThrow
|
||||||
import net.corda.core.random63BitValue
|
import net.corda.core.random63BitValue
|
||||||
import net.corda.core.serialization.OpaqueBytes
|
import net.corda.core.serialization.OpaqueBytes
|
||||||
@ -17,7 +17,6 @@ import net.corda.flows.CashCommand
|
|||||||
import net.corda.flows.CashFlow
|
import net.corda.flows.CashFlow
|
||||||
import net.corda.flows.NotaryFlow
|
import net.corda.flows.NotaryFlow
|
||||||
import net.corda.node.services.persistence.checkpoints
|
import net.corda.node.services.persistence.checkpoints
|
||||||
import net.corda.node.services.statemachine.StateMachineManager.*
|
|
||||||
import net.corda.node.utilities.databaseTransaction
|
import net.corda.node.utilities.databaseTransaction
|
||||||
import net.corda.testing.expect
|
import net.corda.testing.expect
|
||||||
import net.corda.testing.expectEvents
|
import net.corda.testing.expectEvents
|
||||||
@ -215,15 +214,15 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
assertSessionTransfers(node2,
|
assertSessionTransfers(node2,
|
||||||
node1 sent sessionInit(SendFlow::class, payload) to node2,
|
node1 sent sessionInit(SendFlow::class, payload) to node2,
|
||||||
node2 sent sessionConfirm() to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node1 sent sessionEnd() to node2
|
node1 sent sessionEnd to node2
|
||||||
//There's no session end from the other flows as they're manually suspended
|
//There's no session end from the other flows as they're manually suspended
|
||||||
)
|
)
|
||||||
|
|
||||||
assertSessionTransfers(node3,
|
assertSessionTransfers(node3,
|
||||||
node1 sent sessionInit(SendFlow::class, payload) to node3,
|
node1 sent sessionInit(SendFlow::class, payload) to node3,
|
||||||
node3 sent sessionConfirm() to node1,
|
node3 sent sessionConfirm to node1,
|
||||||
node1 sent sessionEnd() to node3
|
node1 sent sessionEnd to node3
|
||||||
//There's no session end from the other flows as they're manually suspended
|
//There's no session end from the other flows as they're manually suspended
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -248,16 +247,16 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
assertSessionTransfers(node2,
|
assertSessionTransfers(node2,
|
||||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
||||||
node2 sent sessionConfirm() to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node2 sent sessionData(node2Payload) to node1,
|
node2 sent sessionData(node2Payload) to node1,
|
||||||
node2 sent sessionEnd() to node1
|
node2 sent sessionEnd to node1
|
||||||
)
|
)
|
||||||
|
|
||||||
assertSessionTransfers(node3,
|
assertSessionTransfers(node3,
|
||||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
|
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
|
||||||
node3 sent sessionConfirm() to node1,
|
node3 sent sessionConfirm to node1,
|
||||||
node3 sent sessionData(node3Payload) to node1,
|
node3 sent sessionData(node3Payload) to node1,
|
||||||
node3 sent sessionEnd() to node1
|
node3 sent sessionEnd to node1
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,11 +268,11 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
assertSessionTransfers(
|
assertSessionTransfers(
|
||||||
node1 sent sessionInit(PingPongFlow::class, 10L) to node2,
|
node1 sent sessionInit(PingPongFlow::class, 10L) to node2,
|
||||||
node2 sent sessionConfirm() to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node2 sent sessionData(20L) to node1,
|
node2 sent sessionData(20L) to node1,
|
||||||
node1 sent sessionData(11L) to node2,
|
node1 sent sessionData(11L) to node2,
|
||||||
node2 sent sessionData(21L) to node1,
|
node2 sent sessionData(21L) to node1,
|
||||||
node1 sent sessionEnd() to node2
|
node1 sent sessionEnd to node2
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,11 +332,11 @@ class StateMachineManagerTests {
|
|||||||
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
|
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
|
||||||
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture
|
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowSessionException::class.java)
|
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java)
|
||||||
assertSessionTransfers(
|
assertSessionTransfers(
|
||||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
||||||
node2 sent sessionConfirm() to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node2 sent sessionEnd() to node1
|
node2 sent sessionEnd to node1
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -358,11 +357,11 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
|
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
|
||||||
|
|
||||||
private fun sessionConfirm() = SessionConfirm(0, 0)
|
private val sessionConfirm = SessionConfirm(0, 0)
|
||||||
|
|
||||||
private fun sessionData(payload: Any) = SessionData(0, payload)
|
private fun sessionData(payload: Any) = SessionData(0, payload)
|
||||||
|
|
||||||
private fun sessionEnd() = SessionEnd(0)
|
private val sessionEnd = SessionEnd(0)
|
||||||
|
|
||||||
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
|
private fun assertSessionTransfers(vararg expected: SessionTransfer) {
|
||||||
assertThat(sessionTransfers).containsExactly(*expected)
|
assertThat(sessionTransfers).containsExactly(*expected)
|
||||||
@ -462,5 +461,4 @@ class StateMachineManagerTests {
|
|||||||
private object ExceptionFlow : FlowLogic<Nothing>() {
|
private object ExceptionFlow : FlowLogic<Nothing>() {
|
||||||
override fun call(): Nothing = throw Exception()
|
override fun call(): Nothing = throw Exception()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -120,7 +120,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten
|
|||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
val acceptorTx = node2.initiateSingleShotFlow(Instigator::class) { Acceptor(it) }.flatMap {
|
val acceptorTx = node2.initiateSingleShotFlow(Instigator::class) { Acceptor(it) }.flatMap {
|
||||||
(it.fsm as FlowStateMachine<SignedTransaction>).resultFuture
|
(it.stateMachine as FlowStateMachine<SignedTransaction>).resultFuture
|
||||||
}
|
}
|
||||||
|
|
||||||
showProgressFor(listOf(node1, node2))
|
showProgressFor(listOf(node1, node2))
|
||||||
|
@ -53,7 +53,7 @@ class TradeSimulation(runAsync: Boolean, latencyInjector: InMemoryMessagingNetwo
|
|||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
val buyerFuture = buyer.initiateSingleShotFlow(Seller::class) {
|
val buyerFuture = buyer.initiateSingleShotFlow(Seller::class) {
|
||||||
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
|
Buyer(it, notary.info.notaryIdentity, amount, CommercialPaper.State::class.java)
|
||||||
}.flatMap { (it.fsm as FlowStateMachine<SignedTransaction>).resultFuture }
|
}.flatMap { (it.stateMachine as FlowStateMachine<SignedTransaction>).resultFuture }
|
||||||
|
|
||||||
val sellerKey = seller.services.legalIdentityKey
|
val sellerKey = seller.services.legalIdentityKey
|
||||||
val sellerFlow = Seller(
|
val sellerFlow = Seller(
|
||||||
|
@ -17,7 +17,9 @@ import net.corda.core.then
|
|||||||
import net.corda.core.utilities.ProgressTracker
|
import net.corda.core.utilities.ProgressTracker
|
||||||
import net.corda.netmap.VisualiserViewModel.Style
|
import net.corda.netmap.VisualiserViewModel.Style
|
||||||
import net.corda.node.services.network.NetworkMapService
|
import net.corda.node.services.network.NetworkMapService
|
||||||
import net.corda.node.services.statemachine.StateMachineManager
|
import net.corda.node.services.statemachine.SessionConfirm
|
||||||
|
import net.corda.node.services.statemachine.SessionEnd
|
||||||
|
import net.corda.node.services.statemachine.SessionInit
|
||||||
import net.corda.simulation.IRSSimulation
|
import net.corda.simulation.IRSSimulation
|
||||||
import net.corda.simulation.Simulation
|
import net.corda.simulation.Simulation
|
||||||
import net.corda.testing.node.InMemoryMessagingNetwork
|
import net.corda.testing.node.InMemoryMessagingNetwork
|
||||||
@ -349,13 +351,12 @@ class NetworkMapVisualiser : Application() {
|
|||||||
// Network map push acknowledgements are boring.
|
// Network map push acknowledgements are boring.
|
||||||
if (NetworkMapService.PUSH_ACK_FLOW_TOPIC in transfer.message.topicSession.topic) return false
|
if (NetworkMapService.PUSH_ACK_FLOW_TOPIC in transfer.message.topicSession.topic) return false
|
||||||
val message = transfer.message.data.deserialize<Any>()
|
val message = transfer.message.data.deserialize<Any>()
|
||||||
val messageClassType = message.javaClass.name
|
return when (message) {
|
||||||
when (messageClassType) {
|
is SessionEnd -> false
|
||||||
StateMachineManager.SessionEnd::class.java.name -> return false
|
is SessionConfirm -> false
|
||||||
StateMachineManager.SessionConfirm::class.java.name -> return false
|
is SessionInit -> message.firstPayload != null
|
||||||
StateMachineManager.SessionInit::class.java.name -> if ((message as StateMachineManager.SessionInit).firstPayload == null) return false
|
else -> true
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user