Making sure exceptions thrown while fibers are suspended are handled properly

This commit is contained in:
Shams Asari 2017-01-20 16:21:31 +00:00
parent 97ca6e7d8b
commit 052a660c1b
3 changed files with 77 additions and 49 deletions

View File

@ -30,7 +30,7 @@ abstract class FlowLogic<out T> {
val logger: Logger get() = stateMachine.logger val logger: Logger get() = stateMachine.logger
/** Returns a wrapped [UUID] object that identifies this state machine run (i.e. subflows have the same identifier as their parents). */ /** Returns a wrapped [UUID] object that identifies this state machine run (i.e. subflows have the same identifier as their parents). */
val runId: StateMachineRunId get() = sessionFlow.stateMachine.id val runId: StateMachineRunId get() = stateMachine.id
/** /**
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is * Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is

View File

@ -6,23 +6,14 @@ import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.node.ServiceHub
import net.corda.core.random63BitValue import net.corda.core.random63BitValue
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.flows.BroadcastTransactionFlow
import net.corda.flows.FinalityFlow
import net.corda.flows.ResolveTransactionsFlow
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.statemachine.StateMachineManager.FlowSession import net.corda.node.services.statemachine.StateMachineManager.FlowSession
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
@ -55,12 +46,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
// These fields shouldn't be serialised, so they are marked @Transient. // These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal @Transient override lateinit var serviceHub: ServiceHubInternal
@Transient internal lateinit var database: Database
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false @Transient internal var fromCheckpoint: Boolean = false
@Transient internal var txTrampoline: Transaction? = null @Transient private var txTrampoline: Transaction? = null
@Transient private var _logger: Logger? = null @Transient private var _logger: Logger? = null
override val logger: Logger get() { override val logger: Logger get() {
@ -255,30 +246,28 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
txTrampoline = TransactionManager.currentOrNull() txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null) StrandLocalTransactionManager.setThreadLocalTx(null)
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>) ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
var exceptionDuringSuspend: Throwable? = null
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended on $ioRequest" } logger.trace { "Suspended on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
try { try {
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
actionOnSuspend(ioRequest) actionOnSuspend(ioRequest)
} catch (t: Throwable) { } catch (t: Throwable) {
// Do not throw exception again - Quasar completely bins it. // Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to
logger.warn("Captured exception which was swallowed by Quasar for $logic at ${fiber.stackTrace.toList().joinToString("\n")}", t) // resume the fiber just so that we can throw it when it's running.
// TODO When error handling is introduced, look into whether we should be deleting the checkpoint and exceptionDuringSuspend = t
// completing the Future resume(scheduler)
processException(t)
} }
} }
logger.trace { "Resumed from $ioRequest" }
createTransaction()
}
private fun processException(t: Throwable) { createTransaction()
// This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here. // TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate
createDatabaseTransaction(database) // the fiber when exceptions occur inside a suspend.
actionOnEnd() exceptionDuringSuspend?.let { throw it }
_resultFuture?.setException(t) logger.trace { "Resumed from $ioRequest" }
} }
internal fun resume(scheduler: FiberScheduler) { internal fun resume(scheduler: FiberScheduler) {

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand.UncaughtExceptionHandler
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.issuedBy import net.corda.core.contracts.issuedBy
@ -10,7 +11,9 @@ import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.getOrThrow import net.corda.core.getOrThrow
import net.corda.core.map
import net.corda.core.random63BitValue import net.corda.core.random63BitValue
import net.corda.core.rootCause
import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.OpaqueBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.flows.CashCommand import net.corda.flows.CashCommand
@ -27,7 +30,7 @@ import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.sequence import net.corda.testing.sequence
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -84,9 +87,34 @@ class StateMachineManagerTests {
assertThat(flow.lazyTime).isNotNull() assertThat(flow.lazyTime).isNotNull()
} }
@Test
fun `exception while fiber suspended`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) }
val flow = ReceiveFlow(node2.info.legalIdentity)
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
val exceptionDuringSuspend = Exception("Thrown during suspend")
fiber.actionOnSuspend = {
throw exceptionDuringSuspend
}
var uncaughtException: Throwable? = null
fiber.uncaughtExceptionHandler = UncaughtExceptionHandler { f, e ->
uncaughtException = e
}
net.runNetwork()
assertThatThrownBy {
fiber.resultFuture.getOrThrow()
}.isSameAs(exceptionDuringSuspend)
assertThat(node1.smm.allStateMachines).isEmpty()
// Make sure it doesn't get swallowed up
assertThat(uncaughtException?.rootCause).isSameAs(exceptionDuringSuspend)
// Make sure the fiber does actually terminate
assertThat(fiber.isTerminated).isTrue()
}
@Test @Test
fun `flow restarted just after receiving payload`() { fun `flow restarted just after receiving payload`() {
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) } node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue() val payload = random63BitValue()
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity)) node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity))
@ -96,7 +124,7 @@ class StateMachineManagerTests {
node2.acceptableLiveFiberCountOnStop = 1 node2.acceptableLiveFiberCountOnStop = 1
node2.stop() node2.stop()
net.runNetwork() net.runNetwork()
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveThenSuspendFlow>(node1) val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
} }
@ -138,13 +166,13 @@ class StateMachineManagerTests {
@Test @Test
fun `flow loaded from checkpoint will respond to messages from before start`() { fun `flow loaded from checkpoint will respond to messages from before start`() {
val payload = random63BitValue() val payload = random63BitValue()
node1.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(payload, it) } node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) }
node2.services.startFlow(ReceiveThenSuspendFlow(node1.info.legalIdentity)) // Prepare checkpointed receive flow node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
node2.smm.executor.flush() node2.smm.executor.flush()
node2.disableDBCloseOnStop() node2.disableDBCloseOnStop()
node2.stop() // kill receiver node2.stop() // kill receiver
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveThenSuspendFlow>(node1) val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
} }
@ -202,13 +230,13 @@ class StateMachineManagerTests {
fun `sending to multiple parties`() { fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address) val node3 = net.createNode(node1.info.address)
net.runNetwork() net.runNetwork()
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) } node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) } node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
val payload = random63BitValue() val payload = random63BitValue()
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity)) node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
net.runNetwork() net.runNetwork()
val node2Flow = node2.getSingleFlow<ReceiveThenSuspendFlow>().first val node2Flow = node2.getSingleFlow<ReceiveFlow>().first
val node3Flow = node3.getSingleFlow<ReceiveThenSuspendFlow>().first val node3Flow = node3.getSingleFlow<ReceiveFlow>().first
assertThat(node2Flow.receivedPayloads[0]).isEqualTo(payload) assertThat(node2Flow.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload) assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload)
@ -236,9 +264,9 @@ class StateMachineManagerTests {
net.runNetwork() net.runNetwork()
val node2Payload = random63BitValue() val node2Payload = random63BitValue()
val node3Payload = random63BitValue() val node3Payload = random63BitValue()
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node2Payload, it) } node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) }
node3.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node3Payload, it) } node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) }
val multiReceiveFlow = ReceiveThenSuspendFlow(node2.info.legalIdentity, node3.info.legalIdentity) val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
node1.services.startFlow(multiReceiveFlow) node1.services.startFlow(multiReceiveFlow)
node1.acceptableLiveFiberCountOnStop = 1 node1.acceptableLiveFiberCountOnStop = 1
net.runNetwork() net.runNetwork()
@ -246,14 +274,14 @@ class StateMachineManagerTests {
assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(node3Payload) assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(node3Payload)
assertSessionTransfers(node2, assertSessionTransfers(node2,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1, node2 sent sessionConfirm to node1,
node2 sent sessionData(node2Payload) to node1, node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd to node1 node2 sent sessionEnd to node1
) )
assertSessionTransfers(node3, assertSessionTransfers(node3,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3, node1 sent sessionInit(ReceiveFlow::class) to node3,
node3 sent sessionConfirm to node1, node3 sent sessionConfirm to node1,
node3 sent sessionData(node3Payload) to node1, node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd to node1 node3 sent sessionEnd to node1
@ -329,12 +357,14 @@ class StateMachineManagerTests {
@Test @Test
fun `exception thrown on other side`() { fun `exception thrown on other side`() {
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow } val erroringFiber = node2.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow }.map { it.stateMachine as FlowStateMachineImpl }
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) as FlowStateMachineImpl
net.runNetwork() net.runNetwork()
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java) assertThatThrownBy { receivingFiber.resultFuture.getOrThrow() }.isInstanceOf(FlowException::class.java)
assertThat(receivingFiber.isTerminated).isTrue()
assertThat(erroringFiber.getOrThrow().isTerminated).isTrue()
assertSessionTransfers( assertSessionTransfers(
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1, node2 sent sessionConfirm to node1,
node2 sent sessionEnd to node1 node2 sent sessionEnd to node1
) )
@ -433,7 +463,9 @@ class StateMachineManagerTests {
} }
private class ReceiveThenSuspendFlow(vararg val otherParties: Party) : FlowLogic<Unit>() { private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
private var nonTerminating: Boolean = false
init { init {
require(otherParties.isNotEmpty()) require(otherParties.isNotEmpty())
} }
@ -443,7 +475,14 @@ class StateMachineManagerTests {
@Suspendable @Suspendable
override fun call() { override fun call() {
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } } receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
Fiber.park() if (nonTerminating) {
Fiber.park()
}
}
fun nonTerminating(): ReceiveFlow {
nonTerminating = true
return this
} }
} }