End flow if waiting for ledger commit and committer flow errors

This commit is contained in:
Shams Asari 2017-02-08 16:00:39 +00:00 committed by Chris Rankin
parent 3c0d6fd14f
commit 71182ec8c1
7 changed files with 244 additions and 152 deletions

View File

@ -203,7 +203,7 @@ abstract class FlowLogic<out T> {
val theirs = subLogic.progressTracker val theirs = subLogic.progressTracker
if (ours != null && theirs != null) { if (ours != null && theirs != null) {
if (ours.currentStep == ProgressTracker.UNSTARTED) { if (ours.currentStep == ProgressTracker.UNSTARTED) {
logger.warn("ProgressTracker has not been started for $this") logger.warn("ProgressTracker has not been started")
ours.nextStep() ours.nextStep()
} }
ours.setChildProgressTracker(ours.currentStep, theirs) ours.setChildProgressTracker(ours.currentStep, theirs)

View File

@ -10,6 +10,8 @@ interface FlowIORequest {
val stackTraceInCaseOfProblems: StackSnapshot val stackTraceInCaseOfProblems: StackSnapshot
} }
interface WaitingRequest : FlowIORequest
interface SessionedFlowIORequest : FlowIORequest { interface SessionedFlowIORequest : FlowIORequest {
val session: FlowSession val session: FlowSession
} }
@ -18,7 +20,7 @@ interface SendRequest : SessionedFlowIORequest {
val message: SessionMessage val message: SessionMessage
} }
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest { interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T> val receiveType: Class<T>
} }
@ -40,7 +42,7 @@ data class SendOnly(override val session: FlowSession, override val message: Ses
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }
data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : FlowIORequest { data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachineImpl<*>) : WaitingRequest {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }

View File

@ -51,7 +51,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Transient override lateinit var serviceHub: ServiceHubInternal @Transient override lateinit var serviceHub: ServiceHubInternal
@Transient internal lateinit var database: Database @Transient internal lateinit var database: Database
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: (Pair<FlowException, Boolean>?) -> Unit @Transient internal lateinit var actionOnEnd: (Throwable?, Boolean) -> Unit
@Transient internal var fromCheckpoint: Boolean = false @Transient internal var fromCheckpoint: Boolean = false
@Transient private var txTrampoline: Transaction? = null @Transient private var txTrampoline: Transaction? = null
@ -76,7 +76,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// This state IS serialised, as we need it to know what the fiber is waiting for. // This state IS serialised, as we need it to know what the fiber is waiting for.
internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>() internal val openSessions = HashMap<Pair<FlowLogic<*>, Party>, FlowSession>()
internal var waitingForLedgerCommitOf: SecureHash? = null internal var waitingForResponse: WaitingRequest? = null
init { init {
logic.stateMachine = this logic.stateMachine = this
@ -91,11 +91,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} catch (e: FlowException) { } catch (e: FlowException) {
// Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive). // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
val propagated = e.stackTrace[0].className == javaClass.name val propagated = e.stackTrace[0].className == javaClass.name
actionOnEnd(Pair(e, propagated)) actionOnEnd(e, propagated)
_resultFuture?.setException(e) _resultFuture?.setException(e)
return return
} catch (t: Throwable) { } catch (t: Throwable) {
actionOnEnd(null) actionOnEnd(t, false)
_resultFuture?.setException(t) _resultFuture?.setException(t)
throw ExecutionException(t) throw ExecutionException(t)
} }
@ -105,7 +105,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
.filter { it.state is FlowSessionState.Initiating } .filter { it.state is FlowSessionState.Initiating }
.forEach { it.waitForConfirmation() } .forEach { it.waitForConfirmation() }
// This is to prevent actionOnEnd being called twice if it throws an exception // This is to prevent actionOnEnd being called twice if it throws an exception
actionOnEnd(null) actionOnEnd(null, false)
_resultFuture?.set(result) _resultFuture?.set(result)
} }
@ -136,10 +136,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> { sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
val session = getConfirmedSession(otherParty, sessionFlow) val session = getConfirmedSession(otherParty, sessionFlow)
return if (session == null) { return if (session == null) {
val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true)
// Only do a receive here as the session init has carried the payload // Only do a receive here as the session init has carried the payload
receiveInternal<SessionData>(startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true)) receiveInternal<SessionData>(newSession, receiveType)
} else { } else {
sendAndReceiveInternal<SessionData>(session, createSessionData(session, payload)) sendAndReceiveInternal<SessionData>(session, createSessionData(session, payload), receiveType)
}.checkPayloadIs(receiveType) }.checkPayloadIs(receiveType)
} }
@ -147,8 +148,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
override fun <T : Any> receive(receiveType: Class<T>, override fun <T : Any> receive(receiveType: Class<T>,
otherParty: Party, otherParty: Party,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> { sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
val session = getConfirmedSession(otherParty, sessionFlow) ?: startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true) val session = getConfirmedSession(otherParty, sessionFlow) ?:
return receiveInternal<SessionData>(session).checkPayloadIs(receiveType) startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true)
return receiveInternal<SessionData>(session, receiveType).checkPayloadIs(receiveType)
} }
@Suspendable @Suspendable
@ -167,7 +169,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/ */
@Suspendable @Suspendable
private fun FlowSession.waitForConfirmation() { private fun FlowSession.waitForConfirmation() {
val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this) val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this, null)
if (sessionInitResponse is SessionConfirm) { if (sessionInitResponse is SessionConfirm) {
state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId) state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId)
} else { } else {
@ -178,12 +180,19 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction { override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction {
waitingForLedgerCommitOf = hash
logger.info("Waiting for transaction $hash to commit") logger.info("Waiting for transaction $hash to commit")
suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>)) suspend(WaitForLedgerCommit(hash, sessionFlow.stateMachine as FlowStateMachineImpl<*>))
logger.info("Transaction $hash has committed to the ledger, resuming")
val stx = serviceHub.storageService.validatedTransactions.getTransaction(hash) val stx = serviceHub.storageService.validatedTransactions.getTransaction(hash)
return stx ?: throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage") if (stx != null) return stx
// If the tx isn't committed then we may have been resumed due to an session ending in an error
for (session in openSessions.values) {
for (receivedMessage in session.receivedMessages) {
if (receivedMessage.message is ErrorSessionEnd) {
session.erroredEnd(receivedMessage.message)
}
}
}
throw IllegalStateException("We were resumed after waiting for $hash but it wasn't found in our local storage")
} }
private fun createSessionData(session: FlowSession, payload: Any): SessionData { private fun createSessionData(session: FlowSession, payload: Any): SessionData {
@ -200,14 +209,17 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
suspend(SendOnly(session, message)) suspend(SendOnly(session, message))
} }
private inline fun <reified M : ExistingSessionMessage> receiveInternal(session: FlowSession): ReceivedSessionMessage<M> { private inline fun <reified M : ExistingSessionMessage> receiveInternal(
return suspendAndExpectReceive(ReceiveOnly(session, M::class.java)) session: FlowSession,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(ReceiveOnly(session, M::class.java), userReceiveType)
} }
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal( private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
session: FlowSession, session: FlowSession,
message: SessionMessage): ReceivedSessionMessage<M> { message: SessionMessage,
return suspendAndExpectReceive(SendAndReceive(session, message, M::class.java)) userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(SendAndReceive(session, message, M::class.java), userReceiveType)
} }
@Suspendable @Suspendable
@ -241,51 +253,72 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
@Suppress("UNCHECKED_CAST", "PLATFORM_CLASS_MAPPED_TO_KOTLIN") private fun <M : ExistingSessionMessage> waitForMessage(
private fun <M : ExistingSessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> { receiveRequest: ReceiveRequest<M>,
val session = receiveRequest.session userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
fun getReceivedMessage(): ReceivedSessionMessage<ExistingSessionMessage>? = session.receivedMessages.poll() val receivedMessage = receiveRequest.suspendAndExpectReceive()
return receivedMessage.confirmReceiveType(receiveRequest, userReceiveType)
}
val polledMessage = getReceivedMessage() @Suspendable
val receivedMessage = if (polledMessage != null) { private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
if (receiveRequest is SendAndReceive) { fun pollForMessage() = session.receivedMessages.poll()
val polledMessage = pollForMessage()
return if (polledMessage != null) {
if (this is SendAndReceive) {
// We've already received a message but we suspend so that the send can be performed // We've already received a message but we suspend so that the send can be performed
suspend(receiveRequest) suspend(this)
} }
polledMessage polledMessage
} else { } else {
// Suspend while we wait for a receive // Suspend while we wait for a receive
suspend(receiveRequest) suspend(this)
getReceivedMessage() ?: pollForMessage() ?:
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead " + throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this")
"got nothing for $receiveRequest")
}
if (receiveRequest.receiveType.isInstance(receivedMessage.message)) {
return receivedMessage as ReceivedSessionMessage<M>
} else if (receivedMessage.message is SessionEnd) {
openSessions.values.remove(session)
if (receivedMessage.message.errorResponse != null) {
(receivedMessage.message.errorResponse as java.lang.Throwable).fillInStackTrace()
throw receivedMessage.message.errorResponse
} else {
throw FlowSessionException("${session.state.sendToParty} has ended their flow but we were expecting " +
"to receive ${receiveRequest.receiveType.simpleName} from them")
}
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but instead got " +
"${receivedMessage.message} for $receiveRequest")
} }
} }
private fun <M : ExistingSessionMessage> ReceivedSessionMessage<*>.confirmReceiveType(
receiveRequest: ReceiveRequest<M>,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
val session = receiveRequest.session
val receiveType = receiveRequest.receiveType
if (receiveType.isInstance(message)) {
@Suppress("UNCHECKED_CAST")
return this as ReceivedSessionMessage<M>
} else if (message is SessionEnd) {
openSessions.values.remove(session)
if (message is ErrorSessionEnd) {
session.erroredEnd(message)
} else {
val expectedType = userReceiveType?.name ?: receiveType.simpleName
throw FlowSessionException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending a $expectedType")
}
} else {
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest")
}
}
private fun FlowSession.erroredEnd(end: ErrorSessionEnd): Nothing {
if (end.errorResponse != null) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(end.errorResponse as java.lang.Throwable).fillInStackTrace()
throw end.errorResponse
} else {
throw FlowSessionException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
}
}
@Suspendable @Suspendable
private fun suspend(ioRequest: FlowIORequest) { private fun suspend(ioRequest: FlowIORequest) {
// We have to pass the thread local database transaction across via a transient field as the fiber park // We have to pass the thread local database transaction across via a transient field as the fiber park
// swaps them out. // swaps them out.
txTrampoline = TransactionManager.currentOrNull() txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null) StrandLocalTransactionManager.setThreadLocalTx(null)
if (ioRequest is SessionedFlowIORequest) if (ioRequest is WaitingRequest)
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>) waitingForResponse = ioRequest
var exceptionDuringSuspend: Throwable? = null var exceptionDuringSuspend: Throwable? = null
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->

View File

@ -7,20 +7,10 @@ import net.corda.core.utilities.UntrustworthyData
interface SessionMessage interface SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage data class SessionInit(val initiatorSessionId: Long, val flowName: String, val firstPayload: Any?) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
} }
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage { data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
@ -29,7 +19,16 @@ data class SessionData(override val recipientSessionId: Long, val payload: Any)
} }
} }
data class SessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : ExistingSessionMessage interface SessionInitResponse : ExistingSessionMessage {
val initiatorSessionId: Long
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionConfirm(override val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
interface SessionEnd : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M) data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)

View File

@ -164,13 +164,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
// Observe the stream of committed, validated transactions and resume fibers that are waiting for them. // Observe the stream of committed, validated transactions and resume fibers that are waiting for them.
serviceHub.storageService.validatedTransactions.updates.subscribe { stx -> serviceHub.storageService.validatedTransactions.updates.subscribe { stx ->
val hash = stx.id val hash = stx.id
val flows: Set<FlowStateMachineImpl<*>> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) } val fibers: Set<FlowStateMachineImpl<*>> = mutex.locked { fibersWaitingForLedgerCommit.removeAll(hash) }
if (flows.isNotEmpty()) { if (fibers.isNotEmpty()) {
executor.executeASAP { executor.executeASAP {
for (flow in flows) { for (fiber in fibers) {
logger.info("Resuming ${flow.id} because it was waiting for tx ${flow.waitingForLedgerCommitOf!!} which is now committed.") fiber.logger.info("Transaction $hash has committed to the ledger, resuming")
flow.waitingForLedgerCommitOf = null fiber.waitingForResponse = null
resumeFiber(flow) resumeFiber(fiber)
} }
} }
} }
@ -239,19 +239,22 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) { private fun resumeRestoredFiber(fiber: FlowStateMachineImpl<*>) {
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it } fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
val waitingForHash = fiber.waitingForLedgerCommitOf val waitingForResponse = fiber.waitingForResponse
if (fiber.openSessions.values.any { it.waitingForResponse }) { if (waitingForResponse != null) {
fiber.logger.info("Restored, pending on receive") if (waitingForResponse is WaitForLedgerCommit) {
} else if (waitingForHash != null) { val stx = databaseTransaction(database) {
val stx = databaseTransaction(database) { serviceHub.storageService.validatedTransactions.getTransaction(waitingForResponse.hash)
serviceHub.storageService.validatedTransactions.getTransaction(waitingForHash) }
} if (stx != null) {
if (stx != null) { fiber.logger.info("Resuming fiber as tx ${waitingForResponse.hash} has committed.")
fiber.logger.info("Resuming fiber as tx $waitingForHash has committed.") fiber.waitingForResponse = null
resumeFiber(fiber) resumeFiber(fiber)
} else {
fiber.logger.info("Restored, pending on ledger commit of ${waitingForResponse.hash}")
mutex.locked { fibersWaitingForLedgerCommit.put(waitingForResponse.hash, fiber) }
}
} else { } else {
fiber.logger.info("Restored, pending on ledger commit of $waitingForHash") fiber.logger.info("Restored, pending on receive")
mutex.locked { fibersWaitingForLedgerCommit.put(waitingForHash, fiber) }
} }
} else { } else {
resumeFiber(fiber) resumeFiber(fiber)
@ -275,15 +278,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) { private fun onExistingSessionMessage(message: ExistingSessionMessage, sender: Party) {
val session = openSessions[message.recipientSessionId] val session = openSessions[message.recipientSessionId]
if (session != null) { if (session != null) {
session.fiber.logger.trace { "Received $message on $session" } session.fiber.logger.trace { "Received $message on $session from $sender" }
if (message is SessionEnd) { if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId) openSessions.remove(message.recipientSessionId)
} }
session.receivedMessages += ReceivedSessionMessage(sender, message) session.receivedMessages += ReceivedSessionMessage(sender, message)
if (session.waitingForResponse) { if (resumeOnMessage(message, session)) {
// We only want to resume once, so immediately reset the flag. // It's important that we reset here and not after the fiber's resumed, in case we receive another message
session.waitingForResponse = false // before then.
session.fiber.waitingForResponse = null
updateCheckpoint(session.fiber) updateCheckpoint(session.fiber)
session.fiber.logger.debug { "About to resume due to $message" }
resumeFiber(session.fiber) resumeFiber(session.fiber)
} }
} else { } else {
@ -291,7 +296,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (peerParty != null) { if (peerParty != null) {
if (message is SessionConfirm) { if (message is SessionConfirm) {
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" } logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId, null)) sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId))
} else { } else {
logger.trace { "Ignoring session end message for already closed session: $message" } logger.trace { "Ignoring session end message for already closed session: $message" }
} }
@ -301,6 +306,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
} }
// We resume the fiber if it's received a response for which it was waiting for or it's waiting for a ledger
// commit but a counterparty flow has ended with an error (in which case our flow also has to end)
private fun resumeOnMessage(message: ExistingSessionMessage, session: FlowSession): Boolean {
val waitingForResponse = session.fiber.waitingForResponse
return (waitingForResponse as? ReceiveRequest<*>)?.session === session ||
waitingForResponse is WaitForLedgerCommit && message is ErrorSessionEnd
}
private fun onSessionInit(sessionInit: SessionInit, sender: Party) { private fun onSessionInit(sessionInit: SessionInit, sender: Party) {
logger.trace { "Received $sessionInit $sender" } logger.trace { "Received $sessionInit $sender" }
val otherPartySessionId = sessionInit.initiatorSessionId val otherPartySessionId = sessionInit.initiatorSessionId
@ -379,14 +392,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
processIORequest(ioRequest) processIORequest(ioRequest)
decrementLiveFibers() decrementLiveFibers()
} }
fiber.actionOnEnd = { errorResponse: Pair<FlowException, Boolean>? -> fiber.actionOnEnd = { exception, propagated ->
try { try {
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked { mutex.locked {
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) } stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
notifyChangeObservers(fiber, AddOrRemove.REMOVE) notifyChangeObservers(fiber, AddOrRemove.REMOVE)
} }
endAllFiberSessions(fiber, errorResponse) endAllFiberSessions(fiber, exception, propagated)
} finally { } finally {
fiber.commitTransaction() fiber.commitTransaction()
decrementLiveFibers() decrementLiveFibers()
@ -401,10 +414,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
} }
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: Pair<FlowException, Boolean>?) { private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, exception: Throwable?, propagated: Boolean) {
openSessions.values.removeIf { session -> openSessions.values.removeIf { session ->
if (session.fiber == fiber) { if (session.fiber == fiber) {
session.endSession(errorResponse) session.endSession(exception, propagated)
true true
} else { } else {
false false
@ -412,22 +425,21 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
} }
private fun FlowSession.endSession(errorResponse: Pair<FlowException, Boolean>?) { private fun FlowSession.endSession(exception: Throwable?, propagated: Boolean) {
val initiatedState = state as? Initiated ?: return val initiatedState = state as? Initiated ?: return
val propagatedException = errorResponse?.let { val sessionEnd = if (exception == null) {
val (exception, propagated) = it NormalSessionEnd(initiatedState.peerSessionId)
if (propagated) { } else {
// This exception was propagated to us. We only propagate it down the invocation chain to the flow that val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) {
// initiated us, not to flows we've started sessions with. // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only
if (initiatingParty != null) exception else null // pass it down invocation chain to the flow that initiated us, not to flows we've started sessions with.
exception
} else { } else {
exception // Our local flow threw the exception so propagate it null
} }
ErrorSessionEnd(initiatedState.peerSessionId, errorResponse)
} }
sendSessionMessage( sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber)
initiatedState.peerParty,
SessionEnd(initiatedState.peerSessionId, propagatedException),
fiber)
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
} }
@ -570,10 +582,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
val flow: FlowLogic<*>, val flow: FlowLogic<*>,
val ourSessionId: Long, val ourSessionId: Long,
val initiatingParty: Party?, val initiatingParty: Party?,
var state: FlowSessionState, var state: FlowSessionState)
@Volatile var waitingForResponse: Boolean = false {
) { val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<*>>()
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<ExistingSessionMessage>>()
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
} }
} }

View File

@ -8,7 +8,6 @@ import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.DummyState import net.corda.core.contracts.DummyState
import net.corda.core.contracts.issuedBy import net.corda.core.contracts.issuedBy
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
@ -21,9 +20,9 @@ import net.corda.core.random63BitValue
import net.corda.core.rootCause import net.corda.core.rootCause
import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.OpaqueBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.utilities.unwrap
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.unwrap
import net.corda.flows.CashIssueFlow import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow import net.corda.flows.CashPaymentFlow
import net.corda.flows.FinalityFlow import net.corda.flows.FinalityFlow
@ -36,6 +35,7 @@ import net.corda.testing.expectEvents
import net.corda.testing.initiateSingleShotFlow import net.corda.testing.initiateSingleShotFlow
import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.sequence import net.corda.testing.sequence
@ -49,10 +49,11 @@ import rx.Observable
import java.util.* import java.util.*
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue import kotlin.test.assertTrue
class StateMachineManagerTests { class StateMachineManagerTests {
private val net = MockNetwork(servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()) private val net = MockNetwork(servicePeerAllocationStrategy = RoundRobin())
private val sessionTransfers = ArrayList<SessionTransfer>() private val sessionTransfers = ArrayList<SessionTransfer>()
private lateinit var node1: MockNode private lateinit var node1: MockNode
private lateinit var node2: MockNode private lateinit var node2: MockNode
@ -102,7 +103,7 @@ class StateMachineManagerTests {
@Test @Test
fun `exception while fiber suspended`() { fun `exception while fiber suspended`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) } node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow("Hello", it) }
val flow = ReceiveFlow(node2.info.legalIdentity) val flow = ReceiveFlow(node2.info.legalIdentity)
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception // Before the flow runs change the suspend action to throw an exception
@ -128,8 +129,7 @@ class StateMachineManagerTests {
@Test @Test
fun `flow restarted just after receiving payload`() { fun `flow restarted just after receiving payload`() {
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue() node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity))
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity))
// We push through just enough messages to get only the payload sent // We push through just enough messages to get only the payload sent
node2.pumpReceive() node2.pumpReceive()
@ -138,7 +138,7 @@ class StateMachineManagerTests {
node2.stop() node2.stop()
net.runNetwork() net.runNetwork()
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1) val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
} }
@Test @Test
@ -178,15 +178,14 @@ class StateMachineManagerTests {
@Test @Test
fun `flow loaded from checkpoint will respond to messages from before start`() { fun `flow loaded from checkpoint will respond to messages from before start`() {
val payload = random63BitValue() node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow("Hello", it) }
node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) }
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
node2.smm.executor.flush() node2.smm.executor.flush()
node2.disableDBCloseOnStop() node2.disableDBCloseOnStop()
node2.stop() // kill receiver node2.stop() // kill receiver
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1) val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredFlow.receivedPayloads[0]).isEqualTo("Hello")
} }
@Test @Test
@ -245,7 +244,7 @@ class StateMachineManagerTests {
net.runNetwork() net.runNetwork()
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue() val payload = "Hello World"
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity)) node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork() net.runNetwork()
val node2Flow = node2.getSingleFlow<ReceiveFlow>().first val node2Flow = node2.getSingleFlow<ReceiveFlow>().first
@ -256,14 +255,14 @@ class StateMachineManagerTests {
assertSessionTransfers(node2, assertSessionTransfers(node2,
node1 sent sessionInit(SendFlow::class, payload) to node2, node1 sent sessionInit(SendFlow::class, payload) to node2,
node2 sent sessionConfirm to node1, node2 sent sessionConfirm to node1,
node1 sent sessionEnd() to node2 node1 sent normalEnd to node2
//There's no session end from the other flows as they're manually suspended //There's no session end from the other flows as they're manually suspended
) )
assertSessionTransfers(node3, assertSessionTransfers(node3,
node1 sent sessionInit(SendFlow::class, payload) to node3, node1 sent sessionInit(SendFlow::class, payload) to node3,
node3 sent sessionConfirm to node1, node3 sent sessionConfirm to node1,
node1 sent sessionEnd() to node3 node1 sent normalEnd to node3
//There's no session end from the other flows as they're manually suspended //There's no session end from the other flows as they're manually suspended
) )
@ -275,8 +274,8 @@ class StateMachineManagerTests {
fun `receiving from multiple parties`() { fun `receiving from multiple parties`() {
val node3 = net.createNode(node1.info.address) val node3 = net.createNode(node1.info.address)
net.runNetwork() net.runNetwork()
val node2Payload = random63BitValue() val node2Payload = "Test 1"
val node3Payload = random63BitValue() val node3Payload = "Test 2"
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) } node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) } node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating() val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
@ -290,14 +289,14 @@ class StateMachineManagerTests {
node1 sent sessionInit(ReceiveFlow::class) to node2, node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1, node2 sent sessionConfirm to node1,
node2 sent sessionData(node2Payload) to node1, node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd() to node1 node2 sent normalEnd to node1
) )
assertSessionTransfers(node3, assertSessionTransfers(node3,
node1 sent sessionInit(ReceiveFlow::class) to node3, node1 sent sessionInit(ReceiveFlow::class) to node3,
node3 sent sessionConfirm to node1, node3 sent sessionConfirm to node1,
node3 sent sessionData(node3Payload) to node1, node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd() to node1 node3 sent normalEnd to node1
) )
} }
@ -313,7 +312,7 @@ class StateMachineManagerTests {
node2 sent sessionData(20L) to node1, node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2, node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1, node2 sent sessionData(21L) to node1,
node1 sent sessionEnd() to node2 node1 sent normalEnd to node2
) )
} }
@ -321,14 +320,14 @@ class StateMachineManagerTests {
fun `different notaries are picked when addressing shared notary identity`() { fun `different notaries are picked when addressing shared notary identity`() {
assertEquals(notary1.info.notaryIdentity, notary2.info.notaryIdentity) assertEquals(notary1.info.notaryIdentity, notary2.info.notaryIdentity)
node1.services.startFlow(CashIssueFlow( node1.services.startFlow(CashIssueFlow(
DOLLARS(2000), 2000.DOLLARS,
OpaqueBytes.of(0x01), OpaqueBytes.of(0x01),
node1.info.legalIdentity, node1.info.legalIdentity,
notary1.info.notaryIdentity)) notary1.info.notaryIdentity))
// We pay a couple of times, the notary picking should go round robin // We pay a couple of times, the notary picking should go round robin
for (i in 1 .. 3) { for (i in 1 .. 3) {
node1.services.startFlow(CashPaymentFlow( node1.services.startFlow(CashPaymentFlow(
DOLLARS(500).issuedBy(node1.info.legalIdentity.ref(0x01)), 500.DOLLARS.issuedBy(node1.info.legalIdentity.ref(0x01)),
node2.info.legalIdentity)) node2.info.legalIdentity))
net.runNetwork() net.runNetwork()
} }
@ -336,7 +335,7 @@ class StateMachineManagerTests {
val party1Info = notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!! val party1Info = notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!
assert(party1Info is PartyInfo.Service) assert(party1Info is PartyInfo.Service)
val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!) val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!)
assert(notary1Address is InMemoryMessagingNetwork.ServiceHandle) assertThat(notary1Address).isInstanceOf(InMemoryMessagingNetwork.ServiceHandle::class.java)
assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!)) assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!))
sessionTransfers.expectEvents(isStrict = false) { sessionTransfers.expectEvents(isStrict = false) {
sequence( sequence(
@ -368,12 +367,38 @@ class StateMachineManagerTests {
}, },
expect(match = { it.message is SessionConfirm }) { expect(match = { it.message is SessionConfirm }) {
it.message as SessionConfirm it.message as SessionConfirm
require(it.from == notary1.id) assertEquals(it.from, notary1.id)
} }
) )
} }
} }
@Test
fun `other side ends before doing expected send`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { NoOpFlow() }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
resultFuture.getOrThrow()
}.withMessageContaining(String::class.java.name)
}
@Test
fun `non-FlowException thrown on other side`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { ExceptionFlow { Exception("evil bug!") } }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
val exceptionResult = assertFailsWith(FlowSessionException::class) {
resultFuture.getOrThrow()
}
assertThat(exceptionResult.message).doesNotContain("evil bug!")
assertSessionTransfers(
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent erroredEnd() to node1
)
}
@Test @Test
fun `FlowException thrown on other side`() { fun `FlowException thrown on other side`() {
val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) { val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) {
@ -384,7 +409,7 @@ class StateMachineManagerTests {
assertThatExceptionOfType(MyFlowException::class.java) assertThatExceptionOfType(MyFlowException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() } .isThrownBy { receivingFiber.resultFuture.getOrThrow() }
.withMessage("Nothing useful") .withMessage("Nothing useful")
.withStackTraceContaining("ReceiveFlow") // Make sure the stack trace is that of the receiving flow .withStackTraceContaining(ReceiveFlow::class.java.name) // Make sure the stack trace is that of the receiving flow
databaseTransaction(node2.database) { databaseTransaction(node2.database) {
assertThat(node2.checkpointStorage.checkpoints()).isEmpty() assertThat(node2.checkpointStorage.checkpoints()).isEmpty()
} }
@ -394,10 +419,10 @@ class StateMachineManagerTests {
assertSessionTransfers( assertSessionTransfers(
node1 sent sessionInit(ReceiveFlow::class) to node2, node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1, node2 sent sessionConfirm to node1,
node2 sent sessionEnd(errorFlow.exceptionThrown) to node1 node2 sent erroredEnd(errorFlow.exceptionThrown) to node1
) )
// Make sure the original stack trace isn't sent down the wire // Make sure the original stack trace isn't sent down the wire
assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty() assertThat((sessionTransfers.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty()
} }
@Test @Test
@ -450,7 +475,7 @@ class StateMachineManagerTests {
node1 sent sessionInit(ReceiveFlow::class) to node2, node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1, node2 sent sessionConfirm to node1,
node2 sent sessionData("Hello") to node1, node2 sent sessionData("Hello") to node1,
node1 sent sessionEnd() to node2 // Unexpected session-end node1 sent erroredEnd() to node2
) )
} }
@ -496,11 +521,29 @@ class StateMachineManagerTests {
ptx.signWith(node1.services.legalIdentityKey) ptx.signWith(node1.services.legalIdentityKey)
val stx = ptx.toSignedTransaction() val stx = ptx.toSignedTransaction()
val future1 = node2.services.startFlow(WaitingFlows.Waiter(stx.id)).resultFuture val committerFiber = node1
val future2 = node1.services.startFlow(WaitingFlows.Committer(stx, node2.info.legalIdentity)).resultFuture .initiateSingleShotFlow(WaitingFlows.Waiter::class) { WaitingFlows.Committer(it) }
.map { it.stateMachine }
val waiterStx = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
net.runNetwork() net.runNetwork()
future1.getOrThrow() assertThat(waiterStx.getOrThrow()).isEqualTo(committerFiber.getOrThrow().resultFuture.getOrThrow())
future2.getOrThrow() }
@Test
fun `committer throws exception before calling the finality flow`() {
val ptx = TransactionBuilder(notary = notary1.info.notaryIdentity)
ptx.addOutputState(DummyState())
ptx.signWith(node1.services.legalIdentityKey)
val stx = ptx.toSignedTransaction()
node1.services.registerFlowInitiator(WaitingFlows.Waiter::class) {
WaitingFlows.Committer(it) { throw Exception("Error") }
}
val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
waiter.getOrThrow()
}
} }
@ -522,12 +565,10 @@ class StateMachineManagerTests {
} }
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload) private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
private val sessionConfirm = SessionConfirm(0, 0) private val sessionConfirm = SessionConfirm(0, 0)
private fun sessionData(payload: Any) = SessionData(0, payload) private fun sessionData(payload: Any) = SessionData(0, payload)
private val normalEnd = NormalSessionEnd(0)
private fun sessionEnd(error: FlowException? = null) = SessionEnd(0, error) private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)
private fun assertSessionTransfers(vararg expected: SessionTransfer) { private fun assertSessionTransfers(vararg expected: SessionTransfer) {
assertThat(sessionTransfers).containsExactly(*expected) assertThat(sessionTransfers).containsExactly(*expected)
@ -557,7 +598,8 @@ class StateMachineManagerTests {
is SessionData -> message.copy(recipientSessionId = 0) is SessionData -> message.copy(recipientSessionId = 0)
is SessionInit -> message.copy(initiatorSessionId = 0) is SessionInit -> message.copy(initiatorSessionId = 0)
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0) is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0)
is SessionEnd -> message.copy(recipientSessionId = 0) is NormalSessionEnd -> message.copy(recipientSessionId = 0)
is ErrorSessionEnd -> message.copy(recipientSessionId = 0)
else -> message else -> message
} }
} }
@ -578,7 +620,7 @@ class StateMachineManagerTests {
} }
private class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic<Unit>() { private class SendFlow(val payload: String, vararg val otherParties: Party) : FlowLogic<Unit>() {
init { init {
require(otherParties.isNotEmpty()) require(otherParties.isNotEmpty())
} }
@ -595,11 +637,11 @@ class StateMachineManagerTests {
require(otherParties.isNotEmpty()) require(otherParties.isNotEmpty())
} }
@Transient var receivedPayloads: List<Any> = emptyList() @Transient var receivedPayloads: List<String> = emptyList()
@Suspendable @Suspendable
override fun call() { override fun call() {
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } } receivedPayloads = otherParties.map { receive<String>(it).unwrap { it } }
if (nonTerminating) { if (nonTerminating) {
Fiber.park() Fiber.park()
} }
@ -630,23 +672,26 @@ class StateMachineManagerTests {
} }
} }
private class MyFlowException(message: String) : FlowException(message) { private class MyFlowException(override val message: String) : FlowException() {
override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message override fun equals(other: Any?): Boolean = other is MyFlowException && other.message == this.message
override fun hashCode(): Int = message?.hashCode() ?: 31 override fun hashCode(): Int = message.hashCode()
} }
private object WaitingFlows { private object WaitingFlows {
class Waiter(private val hash: SecureHash) : FlowLogic<Unit>() { class Waiter(val stx: SignedTransaction, val otherParty: Party) : FlowLogic<SignedTransaction>() {
@Suspendable @Suspendable
override fun call() { override fun call(): SignedTransaction {
waitForLedgerCommit(hash) send(otherParty, stx)
return waitForLedgerCommit(stx.id)
} }
} }
class Committer(private val stx: SignedTransaction, private val otherParty: Party) : FlowLogic<Unit>() { class Committer(val otherParty: Party, val throwException: (() -> Exception)? = null) : FlowLogic<SignedTransaction>() {
@Suspendable @Suspendable
override fun call() { override fun call(): SignedTransaction {
subFlow(FinalityFlow(stx, setOf(otherParty))) val stx = receive<SignedTransaction>(otherParty).unwrap { it }
if (throwException != null) throw throwException.invoke()
return subFlow(FinalityFlow(stx, setOf(otherParty))).single()
} }
} }
} }

View File

@ -282,6 +282,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
* parameter set to -1 (the default) which simply runs as many rounds as necessary to result in network * parameter set to -1 (the default) which simply runs as many rounds as necessary to result in network
* stability (no nodes sent any messages in the last round). * stability (no nodes sent any messages in the last round).
*/ */
@JvmOverloads
fun runNetwork(rounds: Int = -1) { fun runNetwork(rounds: Int = -1) {
check(!networkSendManuallyPumped) check(!networkSendManuallyPumped)
fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) } fun pumpAll() = messagingNetwork.endpoints.map { it.pumpReceive(false) }
@ -324,6 +325,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
* Sets up a network with the requested number of nodes (defaulting to two), with one or more service nodes that * Sets up a network with the requested number of nodes (defaulting to two), with one or more service nodes that
* run a notary, network map, any oracles etc. Can't be combined with [createTwoNodes]. * run a notary, network map, any oracles etc. Can't be combined with [createTwoNodes].
*/ */
@JvmOverloads
fun createSomeNodes(numPartyNodes: Int = 2, nodeFactory: Factory = defaultFactory, notaryKeyPair: KeyPair? = DUMMY_NOTARY_KEY): BasketOfNodes { fun createSomeNodes(numPartyNodes: Int = 2, nodeFactory: Factory = defaultFactory, notaryKeyPair: KeyPair? = DUMMY_NOTARY_KEY): BasketOfNodes {
require(nodes.isEmpty()) require(nodes.isEmpty())
val notaryServiceInfo = ServiceInfo(SimpleNotaryService.type) val notaryServiceInfo = ServiceInfo(SimpleNotaryService.type)