diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 01e92ec092..0f7bf0a58c 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -30,7 +30,7 @@ abstract class FlowLogic { val logger: Logger get() = stateMachine.logger /** Returns a wrapped [UUID] object that identifies this state machine run (i.e. subflows have the same identifier as their parents). */ - val runId: StateMachineRunId get() = sessionFlow.stateMachine.id + val runId: StateMachineRunId get() = stateMachine.id /** * Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 24cfb619e2..4675001e6a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -6,23 +6,14 @@ import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionState import net.corda.core.crypto.Party import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowStateMachine import net.corda.core.flows.StateMachineRunId -import net.corda.core.node.ServiceHub import net.corda.core.random63BitValue -import net.corda.core.transactions.LedgerTransaction -import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.trace -import net.corda.flows.BroadcastTransactionFlow -import net.corda.flows.FinalityFlow -import net.corda.flows.ResolveTransactionsFlow import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.statemachine.StateMachineManager.FlowSession import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState @@ -55,12 +46,12 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } // These fields shouldn't be serialised, so they are marked @Transient. - @Transient lateinit override var serviceHub: ServiceHubInternal + @Transient override lateinit var serviceHub: ServiceHubInternal + @Transient internal lateinit var database: Database @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit - @Transient internal lateinit var database: Database @Transient internal var fromCheckpoint: Boolean = false - @Transient internal var txTrampoline: Transaction? = null + @Transient private var txTrampoline: Transaction? = null @Transient private var _logger: Logger? = null override val logger: Logger get() { @@ -255,30 +246,28 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, txTrampoline = TransactionManager.currentOrNull() StrandLocalTransactionManager.setThreadLocalTx(null) ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>) + + var exceptionDuringSuspend: Throwable? = null parkAndSerialize { fiber, serializer -> logger.trace { "Suspended on $ioRequest" } // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB - StrandLocalTransactionManager.setThreadLocalTx(txTrampoline) - txTrampoline = null try { + StrandLocalTransactionManager.setThreadLocalTx(txTrampoline) + txTrampoline = null actionOnSuspend(ioRequest) } catch (t: Throwable) { - // Do not throw exception again - Quasar completely bins it. - logger.warn("Captured exception which was swallowed by Quasar for $logic at ${fiber.stackTrace.toList().joinToString("\n")}", t) - // TODO When error handling is introduced, look into whether we should be deleting the checkpoint and - // completing the Future - processException(t) + // Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to + // resume the fiber just so that we can throw it when it's running. + exceptionDuringSuspend = t + resume(scheduler) } } - logger.trace { "Resumed from $ioRequest" } - createTransaction() - } - private fun processException(t: Throwable) { - // This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here. - createDatabaseTransaction(database) - actionOnEnd() - _resultFuture?.setException(t) + createTransaction() + // TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate + // the fiber when exceptions occur inside a suspend. + exceptionDuringSuspend?.let { throw it } + logger.trace { "Resumed from $ioRequest" } } internal fun resume(scheduler: FiberScheduler) { diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt index d6ebf752b1..98be2c855e 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt @@ -2,6 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.Strand.UncaughtExceptionHandler import com.google.common.util.concurrent.ListenableFuture import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.issuedBy @@ -10,7 +11,9 @@ import net.corda.core.crypto.generateKeyPair import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic import net.corda.core.getOrThrow +import net.corda.core.map import net.corda.core.random63BitValue +import net.corda.core.rootCause import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.deserialize import net.corda.flows.CashCommand @@ -27,7 +30,7 @@ import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.sequence import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatThrownBy +import org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy import org.junit.After import org.junit.Before import org.junit.Test @@ -84,9 +87,34 @@ class StateMachineManagerTests { assertThat(flow.lazyTime).isNotNull() } + @Test + fun `exception while fiber suspended`() { + node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) } + val flow = ReceiveFlow(node2.info.legalIdentity) + val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl + // Before the flow runs change the suspend action to throw an exception + val exceptionDuringSuspend = Exception("Thrown during suspend") + fiber.actionOnSuspend = { + throw exceptionDuringSuspend + } + var uncaughtException: Throwable? = null + fiber.uncaughtExceptionHandler = UncaughtExceptionHandler { f, e -> + uncaughtException = e + } + net.runNetwork() + assertThatThrownBy { + fiber.resultFuture.getOrThrow() + }.isSameAs(exceptionDuringSuspend) + assertThat(node1.smm.allStateMachines).isEmpty() + // Make sure it doesn't get swallowed up + assertThat(uncaughtException?.rootCause).isSameAs(exceptionDuringSuspend) + // Make sure the fiber does actually terminate + assertThat(fiber.isTerminated).isTrue() + } + @Test fun `flow restarted just after receiving payload`() { - node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) } + node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } val payload = random63BitValue() node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity)) @@ -96,7 +124,7 @@ class StateMachineManagerTests { node2.acceptableLiveFiberCountOnStop = 1 node2.stop() net.runNetwork() - val restoredFlow = node2.restartAndGetRestoredFlow(node1) + val restoredFlow = node2.restartAndGetRestoredFlow(node1) assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) } @@ -138,13 +166,13 @@ class StateMachineManagerTests { @Test fun `flow loaded from checkpoint will respond to messages from before start`() { val payload = random63BitValue() - node1.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(payload, it) } - node2.services.startFlow(ReceiveThenSuspendFlow(node1.info.legalIdentity)) // Prepare checkpointed receive flow + node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) } + node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow // Make sure the add() has finished initial processing. node2.smm.executor.flush() node2.disableDBCloseOnStop() node2.stop() // kill receiver - val restoredFlow = node2.restartAndGetRestoredFlow(node1) + val restoredFlow = node2.restartAndGetRestoredFlow(node1) assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload) } @@ -202,13 +230,13 @@ class StateMachineManagerTests { fun `sending to multiple parties`() { val node3 = net.createNode(node1.info.address) net.runNetwork() - node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) } - node3.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) } + node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } + node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() } val payload = random63BitValue() node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity)) net.runNetwork() - val node2Flow = node2.getSingleFlow().first - val node3Flow = node3.getSingleFlow().first + val node2Flow = node2.getSingleFlow().first + val node3Flow = node3.getSingleFlow().first assertThat(node2Flow.receivedPayloads[0]).isEqualTo(payload) assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload) @@ -236,9 +264,9 @@ class StateMachineManagerTests { net.runNetwork() val node2Payload = random63BitValue() val node3Payload = random63BitValue() - node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node2Payload, it) } - node3.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node3Payload, it) } - val multiReceiveFlow = ReceiveThenSuspendFlow(node2.info.legalIdentity, node3.info.legalIdentity) + node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) } + node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) } + val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating() node1.services.startFlow(multiReceiveFlow) node1.acceptableLiveFiberCountOnStop = 1 net.runNetwork() @@ -246,14 +274,14 @@ class StateMachineManagerTests { assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(node3Payload) assertSessionTransfers(node2, - node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, + node1 sent sessionInit(ReceiveFlow::class) to node2, node2 sent sessionConfirm to node1, node2 sent sessionData(node2Payload) to node1, node2 sent sessionEnd to node1 ) assertSessionTransfers(node3, - node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3, + node1 sent sessionInit(ReceiveFlow::class) to node3, node3 sent sessionConfirm to node1, node3 sent sessionData(node3Payload) to node1, node3 sent sessionEnd to node1 @@ -329,12 +357,14 @@ class StateMachineManagerTests { @Test fun `exception thrown on other side`() { - node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow } - val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture + val erroringFiber = node2.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow }.map { it.stateMachine as FlowStateMachineImpl } + val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) as FlowStateMachineImpl net.runNetwork() - assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java) + assertThatThrownBy { receivingFiber.resultFuture.getOrThrow() }.isInstanceOf(FlowException::class.java) + assertThat(receivingFiber.isTerminated).isTrue() + assertThat(erroringFiber.getOrThrow().isTerminated).isTrue() assertSessionTransfers( - node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, + node1 sent sessionInit(ReceiveFlow::class) to node2, node2 sent sessionConfirm to node1, node2 sent sessionEnd to node1 ) @@ -433,7 +463,9 @@ class StateMachineManagerTests { } - private class ReceiveThenSuspendFlow(vararg val otherParties: Party) : FlowLogic() { + private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic() { + private var nonTerminating: Boolean = false + init { require(otherParties.isNotEmpty()) } @@ -443,7 +475,14 @@ class StateMachineManagerTests { @Suspendable override fun call() { receivedPayloads = otherParties.map { receive(it).unwrap { it } } - Fiber.park() + if (nonTerminating) { + Fiber.park() + } + } + + fun nonTerminating(): ReceiveFlow { + nonTerminating = true + return this } }