diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 334622ea41..2abf0845bf 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -49,7 +49,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Transient override lateinit var serviceHub: ServiceHubInternal @Transient internal lateinit var database: Database @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit - @Transient internal lateinit var actionOnEnd: (FlowException?) -> Unit + @Transient internal lateinit var actionOnEnd: (Pair?) -> Unit @Transient internal var fromCheckpoint: Boolean = false @Transient private var txTrampoline: Transaction? = null @@ -85,15 +85,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val result = try { logic.call() } catch (e: FlowException) { - if (e.stackTrace[0].className == javaClass.name) { - // FlowException was propagated to us as it's stack trace points to this internal class (see suspendAndExpectReceive). - // If we've got to here then the flow doesn't want to handle it and so we end, but we don't propagate - // the exception further as it's not relevant to anyone else. - actionOnEnd(null) - } else { - // FLowException came from this flow - actionOnEnd(e) - } + // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive). + val propagated = e.stackTrace[0].className == javaClass.name + actionOnEnd(Pair(e, propagated)) _resultFuture?.setException(e) return } catch (t: Throwable) { @@ -221,7 +215,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable private fun startNewSession(otherParty: Party.Full, sessionFlow: FlowLogic<*>, firstPayload: Any?, waitForConfirmation: Boolean): FlowSession { logger.trace { "Initiating a new session with $otherParty" } - val session = FlowSession(sessionFlow, random63BitValue(), FlowSessionState.Initiating(otherParty)) + val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty)) openSessions[Pair(sessionFlow, otherParty)] = session val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index b1aef66aa8..30a5002092 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -291,7 +291,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, val session = try { val flow = flowFactory(sender) val fiber = createFiber(flow) - val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(sender, otherPartySessionId)) + val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId)) if (sessionInit.firstPayload != null) { session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) } @@ -345,7 +345,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, processIORequest(ioRequest) decrementLiveFibers() } - fiber.actionOnEnd = { errorResponse: FlowException? -> + fiber.actionOnEnd = { errorResponse: Pair? -> try { fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE mutex.locked { @@ -367,9 +367,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: FlowException?) { - @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - (errorResponse as java.lang.Throwable?)?.stackTrace = emptyArray() + private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: Pair?) { + // TODO Blanking the stack trace prevents the receiving flow from filling in its own stack trace +// @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") +// (errorResponse?.first as java.lang.Throwable?)?.stackTrace = emptyArray() openSessions.values.removeIf { session -> if (session.fiber == fiber) { session.endSession(errorResponse) @@ -380,15 +381,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } - private fun FlowSession.endSession(errorResponse: FlowException?) { - val initiatedState = state as? Initiated - if (initiatedState != null) { - sendSessionMessage( - initiatedState.peerParty, - SessionEnd(initiatedState.peerSessionId, errorResponse), - fiber) - recentlyClosedSessions[ourSessionId] = initiatedState.peerParty + private fun FlowSession.endSession(errorResponse: Pair?) { + val initiatedState = state as? Initiated ?: return + val propagatedException = errorResponse?.let { + val (exception, propagated) = it + if (propagated) { + // This exception was propagated to us. We only propagate it down the invocation chain to the flow that + // initiated us, not to flows we've started sessions with. + if (initiatingParty != null) exception else null + } else { + exception // Our local flow threw the exception so propagate it + } } + sendSessionMessage( + initiatedState.peerParty, + SessionEnd(initiatedState.peerSessionId, propagatedException), + fiber) + recentlyClosedSessions[ourSessionId] = initiatedState.peerParty } private fun startFiber(fiber: FlowStateMachineImpl<*>) { @@ -508,6 +517,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, data class FlowSession( val flow: FlowLogic<*>, val ourSessionId: Long, + val initiatingParty: Party?, var state: FlowSessionState, @Volatile var waitingForResponse: Boolean = false ) { diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt index 856bdffb53..82b26a4148 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt @@ -390,8 +390,23 @@ class StateMachineManagerTests { node2 sent sessionConfirm to node1, node2 sent sessionEnd(errorFlow.exceptionThrown) to node1 ) - // Make sure the original stack trace isn't sent down the wire - assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty() + // TODO see StateMachineManager.endAllFiberSessions +// // Make sure the original stack trace isn't sent down the wire +// assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty() + } + + @Test + fun `FlowException propagated in invocation chain`() { + val node3 = net.createNode(node1.info.address) + net.runNetwork() + + node3.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } } + node2.initiateSingleShotFlow(ReceiveFlow::class) { ReceiveFlow(node3.info.legalIdentity) } + val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) + net.runNetwork() + assertThatExceptionOfType(MyFlowException::class.java) + .isThrownBy { receivingFiber.resultFuture.getOrThrow() } + .withMessage("Chain") } private class SendAndReceiveFlow(val otherParty: Party.Full, val payload: Any) : FlowLogic() { @@ -402,7 +417,7 @@ class StateMachineManagerTests { } @Test - fun `FlowException thrown and there is a 3rd party flow`() { + fun `FlowException thrown and there is a 3rd unrelated party flow`() { val node3 = net.createNode(node1.info.address) net.runNetwork()