mirror of
https://github.com/corda/corda.git
synced 2025-06-18 15:18:16 +00:00
Allow received FlowException to propagate further to initiating flow
This commit is contained in:
@ -49,7 +49,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
@Transient override lateinit var serviceHub: ServiceHubInternal
|
@Transient override lateinit var serviceHub: ServiceHubInternal
|
||||||
@Transient internal lateinit var database: Database
|
@Transient internal lateinit var database: Database
|
||||||
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
|
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
|
||||||
@Transient internal lateinit var actionOnEnd: (FlowException?) -> Unit
|
@Transient internal lateinit var actionOnEnd: (Pair<FlowException, Boolean>?) -> Unit
|
||||||
@Transient internal var fromCheckpoint: Boolean = false
|
@Transient internal var fromCheckpoint: Boolean = false
|
||||||
@Transient private var txTrampoline: Transaction? = null
|
@Transient private var txTrampoline: Transaction? = null
|
||||||
|
|
||||||
@ -85,15 +85,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
val result = try {
|
val result = try {
|
||||||
logic.call()
|
logic.call()
|
||||||
} catch (e: FlowException) {
|
} catch (e: FlowException) {
|
||||||
if (e.stackTrace[0].className == javaClass.name) {
|
// Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
|
||||||
// FlowException was propagated to us as it's stack trace points to this internal class (see suspendAndExpectReceive).
|
val propagated = e.stackTrace[0].className == javaClass.name
|
||||||
// If we've got to here then the flow doesn't want to handle it and so we end, but we don't propagate
|
actionOnEnd(Pair(e, propagated))
|
||||||
// the exception further as it's not relevant to anyone else.
|
|
||||||
actionOnEnd(null)
|
|
||||||
} else {
|
|
||||||
// FLowException came from this flow
|
|
||||||
actionOnEnd(e)
|
|
||||||
}
|
|
||||||
_resultFuture?.setException(e)
|
_resultFuture?.setException(e)
|
||||||
return
|
return
|
||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
@ -221,7 +215,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
@Suspendable
|
@Suspendable
|
||||||
private fun startNewSession(otherParty: Party.Full, sessionFlow: FlowLogic<*>, firstPayload: Any?, waitForConfirmation: Boolean): FlowSession {
|
private fun startNewSession(otherParty: Party.Full, sessionFlow: FlowLogic<*>, firstPayload: Any?, waitForConfirmation: Boolean): FlowSession {
|
||||||
logger.trace { "Initiating a new session with $otherParty" }
|
logger.trace { "Initiating a new session with $otherParty" }
|
||||||
val session = FlowSession(sessionFlow, random63BitValue(), FlowSessionState.Initiating(otherParty))
|
val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty))
|
||||||
openSessions[Pair(sessionFlow, otherParty)] = session
|
openSessions[Pair(sessionFlow, otherParty)] = session
|
||||||
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
|
val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
|
||||||
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
|
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
|
||||||
|
@ -291,7 +291,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
val session = try {
|
val session = try {
|
||||||
val flow = flowFactory(sender)
|
val flow = flowFactory(sender)
|
||||||
val fiber = createFiber(flow)
|
val fiber = createFiber(flow)
|
||||||
val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(sender, otherPartySessionId))
|
val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId))
|
||||||
if (sessionInit.firstPayload != null) {
|
if (sessionInit.firstPayload != null) {
|
||||||
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
|
||||||
}
|
}
|
||||||
@ -345,7 +345,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
processIORequest(ioRequest)
|
processIORequest(ioRequest)
|
||||||
decrementLiveFibers()
|
decrementLiveFibers()
|
||||||
}
|
}
|
||||||
fiber.actionOnEnd = { errorResponse: FlowException? ->
|
fiber.actionOnEnd = { errorResponse: Pair<FlowException, Boolean>? ->
|
||||||
try {
|
try {
|
||||||
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||||
mutex.locked {
|
mutex.locked {
|
||||||
@ -367,9 +367,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: FlowException?) {
|
private fun endAllFiberSessions(fiber: FlowStateMachineImpl<*>, errorResponse: Pair<FlowException, Boolean>?) {
|
||||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
|
// TODO Blanking the stack trace prevents the receiving flow from filling in its own stack trace
|
||||||
(errorResponse as java.lang.Throwable?)?.stackTrace = emptyArray()
|
// @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
|
||||||
|
// (errorResponse?.first as java.lang.Throwable?)?.stackTrace = emptyArray()
|
||||||
openSessions.values.removeIf { session ->
|
openSessions.values.removeIf { session ->
|
||||||
if (session.fiber == fiber) {
|
if (session.fiber == fiber) {
|
||||||
session.endSession(errorResponse)
|
session.endSession(errorResponse)
|
||||||
@ -380,15 +381,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun FlowSession.endSession(errorResponse: FlowException?) {
|
private fun FlowSession.endSession(errorResponse: Pair<FlowException, Boolean>?) {
|
||||||
val initiatedState = state as? Initiated
|
val initiatedState = state as? Initiated ?: return
|
||||||
if (initiatedState != null) {
|
val propagatedException = errorResponse?.let {
|
||||||
sendSessionMessage(
|
val (exception, propagated) = it
|
||||||
initiatedState.peerParty,
|
if (propagated) {
|
||||||
SessionEnd(initiatedState.peerSessionId, errorResponse),
|
// This exception was propagated to us. We only propagate it down the invocation chain to the flow that
|
||||||
fiber)
|
// initiated us, not to flows we've started sessions with.
|
||||||
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
|
if (initiatingParty != null) exception else null
|
||||||
|
} else {
|
||||||
|
exception // Our local flow threw the exception so propagate it
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
sendSessionMessage(
|
||||||
|
initiatedState.peerParty,
|
||||||
|
SessionEnd(initiatedState.peerSessionId, propagatedException),
|
||||||
|
fiber)
|
||||||
|
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun startFiber(fiber: FlowStateMachineImpl<*>) {
|
private fun startFiber(fiber: FlowStateMachineImpl<*>) {
|
||||||
@ -508,6 +517,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
data class FlowSession(
|
data class FlowSession(
|
||||||
val flow: FlowLogic<*>,
|
val flow: FlowLogic<*>,
|
||||||
val ourSessionId: Long,
|
val ourSessionId: Long,
|
||||||
|
val initiatingParty: Party?,
|
||||||
var state: FlowSessionState,
|
var state: FlowSessionState,
|
||||||
@Volatile var waitingForResponse: Boolean = false
|
@Volatile var waitingForResponse: Boolean = false
|
||||||
) {
|
) {
|
||||||
|
@ -390,8 +390,23 @@ class StateMachineManagerTests {
|
|||||||
node2 sent sessionConfirm to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node2 sent sessionEnd(errorFlow.exceptionThrown) to node1
|
node2 sent sessionEnd(errorFlow.exceptionThrown) to node1
|
||||||
)
|
)
|
||||||
// Make sure the original stack trace isn't sent down the wire
|
// TODO see StateMachineManager.endAllFiberSessions
|
||||||
assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty()
|
// // Make sure the original stack trace isn't sent down the wire
|
||||||
|
// assertThat((sessionTransfers.last().message as SessionEnd).errorResponse!!.stackTrace).isEmpty()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `FlowException propagated in invocation chain`() {
|
||||||
|
val node3 = net.createNode(node1.info.address)
|
||||||
|
net.runNetwork()
|
||||||
|
|
||||||
|
node3.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow { MyFlowException("Chain") } }
|
||||||
|
node2.initiateSingleShotFlow(ReceiveFlow::class) { ReceiveFlow(node3.info.legalIdentity) }
|
||||||
|
val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity))
|
||||||
|
net.runNetwork()
|
||||||
|
assertThatExceptionOfType(MyFlowException::class.java)
|
||||||
|
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
|
||||||
|
.withMessage("Chain")
|
||||||
}
|
}
|
||||||
|
|
||||||
private class SendAndReceiveFlow(val otherParty: Party.Full, val payload: Any) : FlowLogic<Unit>() {
|
private class SendAndReceiveFlow(val otherParty: Party.Full, val payload: Any) : FlowLogic<Unit>() {
|
||||||
@ -402,7 +417,7 @@ class StateMachineManagerTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `FlowException thrown and there is a 3rd party flow`() {
|
fun `FlowException thrown and there is a 3rd unrelated party flow`() {
|
||||||
val node3 = net.createNode(node1.info.address)
|
val node3 = net.createNode(node1.info.address)
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user