mirror of
https://github.com/corda/corda.git
synced 2025-05-31 14:40:52 +00:00
Making sure exceptions thrown while fibers are suspended are handled properly
This commit is contained in:
parent
97ca6e7d8b
commit
052a660c1b
@ -30,7 +30,7 @@ abstract class FlowLogic<out T> {
|
|||||||
val logger: Logger get() = stateMachine.logger
|
val logger: Logger get() = stateMachine.logger
|
||||||
|
|
||||||
/** Returns a wrapped [UUID] object that identifies this state machine run (i.e. subflows have the same identifier as their parents). */
|
/** Returns a wrapped [UUID] object that identifies this state machine run (i.e. subflows have the same identifier as their parents). */
|
||||||
val runId: StateMachineRunId get() = sessionFlow.stateMachine.id
|
val runId: StateMachineRunId get() = stateMachine.id
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
||||||
|
@ -6,23 +6,14 @@ import co.paralleluniverse.fibers.Suspendable
|
|||||||
import co.paralleluniverse.strands.Strand
|
import co.paralleluniverse.strands.Strand
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.google.common.util.concurrent.SettableFuture
|
import com.google.common.util.concurrent.SettableFuture
|
||||||
import net.corda.core.contracts.ContractState
|
|
||||||
import net.corda.core.contracts.StateRef
|
|
||||||
import net.corda.core.contracts.TransactionState
|
|
||||||
import net.corda.core.crypto.Party
|
import net.corda.core.crypto.Party
|
||||||
import net.corda.core.flows.FlowException
|
import net.corda.core.flows.FlowException
|
||||||
import net.corda.core.flows.FlowLogic
|
import net.corda.core.flows.FlowLogic
|
||||||
import net.corda.core.flows.FlowStateMachine
|
import net.corda.core.flows.FlowStateMachine
|
||||||
import net.corda.core.flows.StateMachineRunId
|
import net.corda.core.flows.StateMachineRunId
|
||||||
import net.corda.core.node.ServiceHub
|
|
||||||
import net.corda.core.random63BitValue
|
import net.corda.core.random63BitValue
|
||||||
import net.corda.core.transactions.LedgerTransaction
|
|
||||||
import net.corda.core.transactions.SignedTransaction
|
|
||||||
import net.corda.core.utilities.UntrustworthyData
|
import net.corda.core.utilities.UntrustworthyData
|
||||||
import net.corda.core.utilities.trace
|
import net.corda.core.utilities.trace
|
||||||
import net.corda.flows.BroadcastTransactionFlow
|
|
||||||
import net.corda.flows.FinalityFlow
|
|
||||||
import net.corda.flows.ResolveTransactionsFlow
|
|
||||||
import net.corda.node.services.api.ServiceHubInternal
|
import net.corda.node.services.api.ServiceHubInternal
|
||||||
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
import net.corda.node.services.statemachine.StateMachineManager.FlowSession
|
||||||
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
|
import net.corda.node.services.statemachine.StateMachineManager.FlowSessionState
|
||||||
@ -55,12 +46,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
@Transient override lateinit var serviceHub: ServiceHubInternal
|
||||||
|
@Transient internal lateinit var database: Database
|
||||||
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
|
@Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit
|
||||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||||
@Transient internal lateinit var database: Database
|
|
||||||
@Transient internal var fromCheckpoint: Boolean = false
|
@Transient internal var fromCheckpoint: Boolean = false
|
||||||
@Transient internal var txTrampoline: Transaction? = null
|
@Transient private var txTrampoline: Transaction? = null
|
||||||
|
|
||||||
@Transient private var _logger: Logger? = null
|
@Transient private var _logger: Logger? = null
|
||||||
override val logger: Logger get() {
|
override val logger: Logger get() {
|
||||||
@ -255,30 +246,28 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
txTrampoline = TransactionManager.currentOrNull()
|
txTrampoline = TransactionManager.currentOrNull()
|
||||||
StrandLocalTransactionManager.setThreadLocalTx(null)
|
StrandLocalTransactionManager.setThreadLocalTx(null)
|
||||||
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
|
ioRequest.session.waitingForResponse = (ioRequest is ReceiveRequest<*>)
|
||||||
|
|
||||||
|
var exceptionDuringSuspend: Throwable? = null
|
||||||
parkAndSerialize { fiber, serializer ->
|
parkAndSerialize { fiber, serializer ->
|
||||||
logger.trace { "Suspended on $ioRequest" }
|
logger.trace { "Suspended on $ioRequest" }
|
||||||
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
||||||
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
|
|
||||||
txTrampoline = null
|
|
||||||
try {
|
try {
|
||||||
|
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
|
||||||
|
txTrampoline = null
|
||||||
actionOnSuspend(ioRequest)
|
actionOnSuspend(ioRequest)
|
||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
// Do not throw exception again - Quasar completely bins it.
|
// Quasar does not terminate the fiber properly if an exception occurs during a suspend. We have to
|
||||||
logger.warn("Captured exception which was swallowed by Quasar for $logic at ${fiber.stackTrace.toList().joinToString("\n")}", t)
|
// resume the fiber just so that we can throw it when it's running.
|
||||||
// TODO When error handling is introduced, look into whether we should be deleting the checkpoint and
|
exceptionDuringSuspend = t
|
||||||
// completing the Future
|
resume(scheduler)
|
||||||
processException(t)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.trace { "Resumed from $ioRequest" }
|
|
||||||
createTransaction()
|
|
||||||
}
|
|
||||||
|
|
||||||
private fun processException(t: Throwable) {
|
createTransaction()
|
||||||
// This can get called in actionOnSuspend *after* we commit the database transaction, so optionally open a new one here.
|
// TODO Now that we're throwing outside of the suspend the FlowLogic can catch it. We need Quasar to terminate
|
||||||
createDatabaseTransaction(database)
|
// the fiber when exceptions occur inside a suspend.
|
||||||
actionOnEnd()
|
exceptionDuringSuspend?.let { throw it }
|
||||||
_resultFuture?.setException(t)
|
logger.trace { "Resumed from $ioRequest" }
|
||||||
}
|
}
|
||||||
|
|
||||||
internal fun resume(scheduler: FiberScheduler) {
|
internal fun resume(scheduler: FiberScheduler) {
|
||||||
|
@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
|
|||||||
|
|
||||||
import co.paralleluniverse.fibers.Fiber
|
import co.paralleluniverse.fibers.Fiber
|
||||||
import co.paralleluniverse.fibers.Suspendable
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
|
import co.paralleluniverse.strands.Strand.UncaughtExceptionHandler
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import net.corda.core.contracts.DOLLARS
|
import net.corda.core.contracts.DOLLARS
|
||||||
import net.corda.core.contracts.issuedBy
|
import net.corda.core.contracts.issuedBy
|
||||||
@ -10,7 +11,9 @@ import net.corda.core.crypto.generateKeyPair
|
|||||||
import net.corda.core.flows.FlowException
|
import net.corda.core.flows.FlowException
|
||||||
import net.corda.core.flows.FlowLogic
|
import net.corda.core.flows.FlowLogic
|
||||||
import net.corda.core.getOrThrow
|
import net.corda.core.getOrThrow
|
||||||
|
import net.corda.core.map
|
||||||
import net.corda.core.random63BitValue
|
import net.corda.core.random63BitValue
|
||||||
|
import net.corda.core.rootCause
|
||||||
import net.corda.core.serialization.OpaqueBytes
|
import net.corda.core.serialization.OpaqueBytes
|
||||||
import net.corda.core.serialization.deserialize
|
import net.corda.core.serialization.deserialize
|
||||||
import net.corda.flows.CashCommand
|
import net.corda.flows.CashCommand
|
||||||
@ -27,7 +30,7 @@ import net.corda.testing.node.MockNetwork
|
|||||||
import net.corda.testing.node.MockNetwork.MockNode
|
import net.corda.testing.node.MockNetwork.MockNode
|
||||||
import net.corda.testing.sequence
|
import net.corda.testing.sequence
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
import org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy
|
||||||
import org.junit.After
|
import org.junit.After
|
||||||
import org.junit.Before
|
import org.junit.Before
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
@ -84,9 +87,34 @@ class StateMachineManagerTests {
|
|||||||
assertThat(flow.lazyTime).isNotNull()
|
assertThat(flow.lazyTime).isNotNull()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `exception while fiber suspended`() {
|
||||||
|
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(2, it) }
|
||||||
|
val flow = ReceiveFlow(node2.info.legalIdentity)
|
||||||
|
val fiber = node1.services.startFlow(flow) as FlowStateMachineImpl
|
||||||
|
// Before the flow runs change the suspend action to throw an exception
|
||||||
|
val exceptionDuringSuspend = Exception("Thrown during suspend")
|
||||||
|
fiber.actionOnSuspend = {
|
||||||
|
throw exceptionDuringSuspend
|
||||||
|
}
|
||||||
|
var uncaughtException: Throwable? = null
|
||||||
|
fiber.uncaughtExceptionHandler = UncaughtExceptionHandler { f, e ->
|
||||||
|
uncaughtException = e
|
||||||
|
}
|
||||||
|
net.runNetwork()
|
||||||
|
assertThatThrownBy {
|
||||||
|
fiber.resultFuture.getOrThrow()
|
||||||
|
}.isSameAs(exceptionDuringSuspend)
|
||||||
|
assertThat(node1.smm.allStateMachines).isEmpty()
|
||||||
|
// Make sure it doesn't get swallowed up
|
||||||
|
assertThat(uncaughtException?.rootCause).isSameAs(exceptionDuringSuspend)
|
||||||
|
// Make sure the fiber does actually terminate
|
||||||
|
assertThat(fiber.isTerminated).isTrue()
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `flow restarted just after receiving payload`() {
|
fun `flow restarted just after receiving payload`() {
|
||||||
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) }
|
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
|
||||||
val payload = random63BitValue()
|
val payload = random63BitValue()
|
||||||
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity))
|
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity))
|
||||||
|
|
||||||
@ -96,7 +124,7 @@ class StateMachineManagerTests {
|
|||||||
node2.acceptableLiveFiberCountOnStop = 1
|
node2.acceptableLiveFiberCountOnStop = 1
|
||||||
node2.stop()
|
node2.stop()
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveThenSuspendFlow>(node1)
|
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
|
||||||
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
|
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,13 +166,13 @@ class StateMachineManagerTests {
|
|||||||
@Test
|
@Test
|
||||||
fun `flow loaded from checkpoint will respond to messages from before start`() {
|
fun `flow loaded from checkpoint will respond to messages from before start`() {
|
||||||
val payload = random63BitValue()
|
val payload = random63BitValue()
|
||||||
node1.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(payload, it) }
|
node1.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(payload, it) }
|
||||||
node2.services.startFlow(ReceiveThenSuspendFlow(node1.info.legalIdentity)) // Prepare checkpointed receive flow
|
node2.services.startFlow(ReceiveFlow(node1.info.legalIdentity).nonTerminating()) // Prepare checkpointed receive flow
|
||||||
// Make sure the add() has finished initial processing.
|
// Make sure the add() has finished initial processing.
|
||||||
node2.smm.executor.flush()
|
node2.smm.executor.flush()
|
||||||
node2.disableDBCloseOnStop()
|
node2.disableDBCloseOnStop()
|
||||||
node2.stop() // kill receiver
|
node2.stop() // kill receiver
|
||||||
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveThenSuspendFlow>(node1)
|
val restoredFlow = node2.restartAndGetRestoredFlow<ReceiveFlow>(node1)
|
||||||
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
|
assertThat(restoredFlow.receivedPayloads[0]).isEqualTo(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,13 +230,13 @@ class StateMachineManagerTests {
|
|||||||
fun `sending to multiple parties`() {
|
fun `sending to multiple parties`() {
|
||||||
val node3 = net.createNode(node1.info.address)
|
val node3 = net.createNode(node1.info.address)
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) }
|
node2.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
|
||||||
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveThenSuspendFlow(it) }
|
node3.services.registerFlowInitiator(SendFlow::class) { ReceiveFlow(it).nonTerminating() }
|
||||||
val payload = random63BitValue()
|
val payload = random63BitValue()
|
||||||
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
|
node1.services.startFlow(SendFlow(payload, node2.info.legalIdentity, node3.info.legalIdentity))
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
val node2Flow = node2.getSingleFlow<ReceiveThenSuspendFlow>().first
|
val node2Flow = node2.getSingleFlow<ReceiveFlow>().first
|
||||||
val node3Flow = node3.getSingleFlow<ReceiveThenSuspendFlow>().first
|
val node3Flow = node3.getSingleFlow<ReceiveFlow>().first
|
||||||
assertThat(node2Flow.receivedPayloads[0]).isEqualTo(payload)
|
assertThat(node2Flow.receivedPayloads[0]).isEqualTo(payload)
|
||||||
assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload)
|
assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload)
|
||||||
|
|
||||||
@ -236,9 +264,9 @@ class StateMachineManagerTests {
|
|||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
val node2Payload = random63BitValue()
|
val node2Payload = random63BitValue()
|
||||||
val node3Payload = random63BitValue()
|
val node3Payload = random63BitValue()
|
||||||
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node2Payload, it) }
|
node2.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node2Payload, it) }
|
||||||
node3.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { SendFlow(node3Payload, it) }
|
node3.services.registerFlowInitiator(ReceiveFlow::class) { SendFlow(node3Payload, it) }
|
||||||
val multiReceiveFlow = ReceiveThenSuspendFlow(node2.info.legalIdentity, node3.info.legalIdentity)
|
val multiReceiveFlow = ReceiveFlow(node2.info.legalIdentity, node3.info.legalIdentity).nonTerminating()
|
||||||
node1.services.startFlow(multiReceiveFlow)
|
node1.services.startFlow(multiReceiveFlow)
|
||||||
node1.acceptableLiveFiberCountOnStop = 1
|
node1.acceptableLiveFiberCountOnStop = 1
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
@ -246,14 +274,14 @@ class StateMachineManagerTests {
|
|||||||
assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(node3Payload)
|
assertThat(multiReceiveFlow.receivedPayloads[1]).isEqualTo(node3Payload)
|
||||||
|
|
||||||
assertSessionTransfers(node2,
|
assertSessionTransfers(node2,
|
||||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
node1 sent sessionInit(ReceiveFlow::class) to node2,
|
||||||
node2 sent sessionConfirm to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node2 sent sessionData(node2Payload) to node1,
|
node2 sent sessionData(node2Payload) to node1,
|
||||||
node2 sent sessionEnd to node1
|
node2 sent sessionEnd to node1
|
||||||
)
|
)
|
||||||
|
|
||||||
assertSessionTransfers(node3,
|
assertSessionTransfers(node3,
|
||||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
|
node1 sent sessionInit(ReceiveFlow::class) to node3,
|
||||||
node3 sent sessionConfirm to node1,
|
node3 sent sessionConfirm to node1,
|
||||||
node3 sent sessionData(node3Payload) to node1,
|
node3 sent sessionData(node3Payload) to node1,
|
||||||
node3 sent sessionEnd to node1
|
node3 sent sessionEnd to node1
|
||||||
@ -329,12 +357,14 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `exception thrown on other side`() {
|
fun `exception thrown on other side`() {
|
||||||
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
|
val erroringFiber = node2.initiateSingleShotFlow(ReceiveFlow::class) { ExceptionFlow }.map { it.stateMachine as FlowStateMachineImpl }
|
||||||
val future = node1.services.startFlow(ReceiveThenSuspendFlow(node2.info.legalIdentity)).resultFuture
|
val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) as FlowStateMachineImpl
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowException::class.java)
|
assertThatThrownBy { receivingFiber.resultFuture.getOrThrow() }.isInstanceOf(FlowException::class.java)
|
||||||
|
assertThat(receivingFiber.isTerminated).isTrue()
|
||||||
|
assertThat(erroringFiber.getOrThrow().isTerminated).isTrue()
|
||||||
assertSessionTransfers(
|
assertSessionTransfers(
|
||||||
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
|
node1 sent sessionInit(ReceiveFlow::class) to node2,
|
||||||
node2 sent sessionConfirm to node1,
|
node2 sent sessionConfirm to node1,
|
||||||
node2 sent sessionEnd to node1
|
node2 sent sessionEnd to node1
|
||||||
)
|
)
|
||||||
@ -433,7 +463,9 @@ class StateMachineManagerTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private class ReceiveThenSuspendFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
|
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
|
||||||
|
private var nonTerminating: Boolean = false
|
||||||
|
|
||||||
init {
|
init {
|
||||||
require(otherParties.isNotEmpty())
|
require(otherParties.isNotEmpty())
|
||||||
}
|
}
|
||||||
@ -443,7 +475,14 @@ class StateMachineManagerTests {
|
|||||||
@Suspendable
|
@Suspendable
|
||||||
override fun call() {
|
override fun call() {
|
||||||
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
|
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
|
||||||
Fiber.park()
|
if (nonTerminating) {
|
||||||
|
Fiber.park()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun nonTerminating(): ReceiveFlow {
|
||||||
|
nonTerminating = true
|
||||||
|
return this
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user