Making sure exceptions thrown while fibers are suspended are handled properly

This commit is contained in:
Shams Asari 2017-01-20 16:21:31 +00:00
parent 97ca6e7d8b
commit 052a660c1b
3 changed files with 77 additions and 49 deletions

View File

@ -30,7 +30,7 @@ abstract class FlowLogic<out T> {
val logger: Logger get() = stateMachine.logger
/** Returns a wrapped [UUID] object that identifies this state machine run (i.e. subflows have the same identifier as their parents). */
val runId: StateMachineRunId get() = sessionFlow.stateMachine.id
val runId: StateMachineRunId get() = stateMachine.id
/**
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is

View File

@ -6,23 +6,14 @@ import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.Party
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.node.ServiceHub
import net.corda.core.random63BitValue
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.trace
import net.corda.flows.BroadcastTransactionFlow
import net.corda.flows.FinalityFlow
import net.corda.flows.ResolveTransactionsFlow
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
@ -55,12 +46,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal
@Transient override lateinit var serviceHub: ServiceHubInternal
@Transient internal lateinit var database: Database
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false
@Transient internal var txTrampoline: Transaction? = null
@Transient private var txTrampoline: Transaction? = null
@Transient private var _logger: Logger? = null
override val logger: Logger get() {
@ -255,30 +246,28 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
var exceptionDuringSuspend: Throwable? = null
parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
try {
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
actionOnSuspend(ioRequest)
} catch (t: Throwable) {
// Do not throw exception again - Quasar completely bins it.
logger.warn("Captured exception which was swallowed by Quasar for $logic at ${fiber.stackTrace.toList().joinToString("\n")}", t)
// TODO When error handling is introduced, look into whether we should be deleting the checkpoint and
// completing the Future
processException(t)
// Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to
// resume the fiber just so that we can throw it when it's running.
exceptionDuringSuspend = t
resume(scheduler)
}
}
logger.trace { "Resumed from $ioRequest" }
createTransaction()
}
private fun processException(t: Throwable) {
// This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
createDatabaseTransaction(database)
actionOnEnd()
_resultFuture?.setException(t)
createTransaction()
// TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate
// the fiber when exceptions occur inside a suspend.
exceptionDuringSuspend?.let { throw it }
logger.trace { "Resumed from $ioRequest" }
}
internal fun resume(scheduler: FiberScheduler) {

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand.UncaughtExceptionHandler
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.issuedBy
@ -10,7 +11,9 @@ import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.getOrThrow
import net.corda.core.map
import net.corda.core.random63BitValue
import net.corda.core.rootCause
import net.corda.core.serialization.OpaqueBytes
import net.corda.core.serialization.deserialize
import net.corda.flows.CashCommand
@ -27,7 +30,7 @@ import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.sequence
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy
import org.junit.After
import org.junit.Before
import org.junit.Test
@ -84,9 +87,34 @@ class StateMachineManagerTests {
assertThat(flow.lazyTime).isNotNull()
}
@Test
fun `exception while fiber suspended`() {
    // node2 is set up to respond to a ReceiveFlow with a SendFlow carrying the payload 2
    node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) }
    val fiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) as FlowStateMachineImpl

    // Swap in a suspend action that always fails; this happens before the network is run
    val suspendFailure = Exception("Thrown during suspend")
    fiber.actionOnSuspend = { throw suspendFailure }

    // Record whatever reaches the fiber's uncaught exception handler
    var escaped: Throwable? = null
    fiber.uncaughtExceptionHandler = UncaughtExceptionHandler { f, cause -> escaped = cause }

    net.runNetwork()

    // The failure must surface through the result future as the very same instance ...
    assertThatThrownBy { fiber.resultFuture.getOrThrow() }.isSameAs(suspendFailure)
    // ... node1 must no longer hold any state machines ...
    assertThat(node1.smm.allStateMachines).isEmpty()
    // ... the exception must have reached the handler instead of being swallowed ...
    assertThat(escaped?.rootCause).isSameAs(suspendFailure)
    // ... and the fiber itself must report that it has terminated
    assertThat(fiber.isTerminated).isTrue()
}
@Test
fun `flow restarted just after receiving payload`() {
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) }
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue()
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity))
@ -96,7 +124,7 @@ class StateMachineManagerTests {
node2.acceptableLiveFiberCountOnStop = 1
node2.stop()
net.runNetwork()
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveThenSuspendFlow>(node1)
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
}
@ -138,13 +166,13 @@ class StateMachineManagerTests {
@Test
fun `flow loaded from checkpoint will respond to messages from before start`() {
// Starts a receiving flow on node2, stops node2, then checks that the flow obtained from
// restartAndGetRestoredFlow has the payload as its first received value.
// NOTE(review): each ReceiveThenSuspendFlow line below is immediately followed by a ReceiveFlow
// counterpart; these look like the before/after versions of the same statement - confirm which is live.
val payload = random63BitValue()
node1.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(payload, it) }
node2.services.startFlow(ReceiveThenSuspendFlow(node1.info.legalIdentity)) // Prepare checkpointed receive flow
node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) }
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
node2.disableDBCloseOnStop()
node2.stop() // kill receiver
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveThenSuspendFlow>(node1)
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
}
@ -202,13 +230,13 @@ class StateMachineManagerTests {
fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) }
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) }
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue()
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork()
val node2Flow = node2.getSingleFlow<ReceiveThenSuspendFlow>().first
val node3Flow = node3.getSingleFlow<ReceiveThenSuspendFlow>().first
val node2Flow = node2.getSingleFlow<ReceiveFlow>().first
val node3Flow = node3.getSingleFlow<ReceiveFlow>().first
assertThat(node2Flow.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload)
@ -236,9 +264,9 @@ class StateMachineManagerTests {
net.runNetwork()
val node2Payload = random63BitValue()
val node3Payload = random63BitValue()
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveThenSuspendFlow(node2.info.legalIdentity, node3.info.legalIdentity)
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
node1.services.startFlow(multiReceiveFlow)
node1.acceptableLiveFiberCountOnStop = 1
net.runNetwork()
@ -246,14 +274,14 @@ class StateMachineManagerTests {
assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(node3Payload)
assertSessionTransfers(node2,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd to node1
)
assertSessionTransfers(node3,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
node1 sent sessionInit(ReceiveFlow::class) to node3,
node3 sent sessionConfirm to node1,
node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd to node1
@ -329,12 +357,14 @@ class StateMachineManagerTests {
@Test
fun `exception thrown on other side`() {
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture
val erroringFiber = node2.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow }.map { it.stateMachine as FlowStateMachineImpl }
val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) as FlowStateMachineImpl
net.runNetwork()
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java)
assertThatThrownBy { receivingFiber.resultFuture.getOrThrow() }.isInstanceOf(FlowException::class.java)
assertThat(receivingFiber.isTerminated).isTrue()
assertThat(erroringFiber.getOrThrow().isTerminated).isTrue()
assertSessionTransfers(
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent sessionEnd to node1
)
@ -433,7 +463,9 @@ class StateMachineManagerTests {
}
private class ReceiveThenSuspendFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
private var nonTerminating: Boolean = false
init {
require(otherParties.isNotEmpty())
}
@ -443,7 +475,14 @@ class StateMachineManagerTests {
@Suspendable
override fun call() {
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
Fiber.park()
if (nonTerminating) {
Fiber.park()
}
}
fun nonTerminating(): ReceiveFlow {
nonTerminating = true
return this
}
}