ProgressTracker emits exception thrown by the flow, allowing the ANSI renderer to correctly stop and print the error (#189)

This commit is contained in:
Shams Asari
2017-02-15 10:14:24 +00:00
parent ed093cdb9d
commit f13817efb3
10 changed files with 186 additions and 106 deletions

View File

@ -15,6 +15,7 @@ import net.corda.core.flows.FlowStateMachine
import net.corda.core.flows.StateMachineRunId
import net.corda.core.random63BitValue
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
@ -56,8 +57,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Transient private var _logger: Logger? = null
/**
* Return the logger for this state machine. The logger name incorporates [id] and so including this in the log
* message is not necessary.
* Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
* is not necessary.
*/
override val logger: Logger get() {
return _logger ?: run {
@ -94,14 +95,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} catch (e: FlowException) {
// Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive).
val propagated = e.stackTrace[0].className == javaClass.name
actionOnEnd(e, propagated)
_resultFuture?.setException(e)
processException(e, propagated)
logger.debug(if (propagated) "Flow ended due to receiving exception" else "Flow finished with exception", e)
return
} catch (t: Throwable) {
logger.warn("Terminated by unexpected exception", t)
actionOnEnd(t, false)
_resultFuture?.setException(t)
processException(t, false)
return
}
@ -112,6 +111,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// This is to prevent actionOnEnd being called twice if it throws an exception
actionOnEnd(null, false)
_resultFuture?.set(result)
logic.progressTracker?.currentStep = ProgressTracker.DONE
logger.debug { "Flow finished with result $result" }
}
@ -121,6 +121,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}" }
}
private fun processException(exception: Throwable, propagated: Boolean) {
actionOnEnd(exception, propagated)
_resultFuture?.setException(exception)
logic.progressTracker?.endWithError(exception)
}
internal fun commitTransaction() {
val transaction = TransactionManager.current()
try {

View File

@ -24,7 +24,6 @@ import net.corda.core.messaging.send
import net.corda.core.random63BitValue
import net.corda.core.serialization.*
import net.corda.core.then
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.debug
import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.trace
@ -391,7 +390,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
fiber.actionOnEnd = { exception, propagated ->
try {
fiber.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked {
stateMachines.remove(fiber)?.let { checkpointStorage.removeCheckpoint(it) }
notifyChangeObservers(fiber, AddOrRemove.REMOVE)

View File

@ -2,7 +2,6 @@ package net.corda.node.utilities
import net.corda.core.ThreadBox
import net.corda.core.flows.FlowLogic
import net.corda.core.utilities.ProgressTracker
import net.corda.node.services.statemachine.StateMachineManager
import java.util.*
@ -35,13 +34,12 @@ class ANSIProgressObserver(val smm: StateMachineManager) {
if (currentlyRendering?.progressTracker != null) {
ANSIProgressRenderer.progressTracker = currentlyRendering!!.progressTracker
}
} while (currentlyRendering?.progressTracker?.currentStep == ProgressTracker.DONE)
} while (currentlyRendering?.progressTracker?.hasEnded ?: false)
}
}
private fun removeFlowLogic(flowLogic: FlowLogic<*>) {
state.locked {
flowLogic.progressTracker?.currentStep = ProgressTracker.DONE
if (currentlyRendering == flowLogic) {
wireUpProgressRendering()
}
@ -51,7 +49,7 @@ class ANSIProgressObserver(val smm: StateMachineManager) {
private fun addFlowLogic(flowLogic: FlowLogic<*>) {
state.locked {
pending.add(flowLogic)
if ((currentlyRendering?.progressTracker?.currentStep ?: ProgressTracker.DONE) == ProgressTracker.DONE) {
if (currentlyRendering?.progressTracker?.hasEnded ?: true) {
wireUpProgressRendering()
}
}

View File

@ -1,6 +1,9 @@
package net.corda.node.utilities
import net.corda.core.utilities.Emoji
import net.corda.core.utilities.Emoji.CODE_GREEN_TICK
import net.corda.core.utilities.Emoji.CODE_NO_ENTRY
import net.corda.core.utilities.Emoji.CODE_RIGHT_ARROW
import net.corda.core.utilities.Emoji.SKULL_AND_CROSSBONES
import net.corda.core.utilities.ProgressTracker
import net.corda.node.utilities.ANSIProgressRenderer.progressTracker
import org.apache.logging.log4j.LogManager
@ -43,7 +46,7 @@ object ANSIProgressRenderer {
prevMessagePrinted = null
prevLinesDrawn = 0
draw(true)
subscription = value?.changes?.subscribe { draw(true) }
subscription = value?.changes?.subscribe({ draw(true) }, { draw(true, it) })
}
private fun setup() {
@ -102,7 +105,7 @@ object ANSIProgressRenderer {
// prevLinesDraw is just for ANSI mode.
private var prevLinesDrawn = 0
@Synchronized private fun draw(moveUp: Boolean) {
@Synchronized private fun draw(moveUp: Boolean, error: Throwable? = null) {
val pt = progressTracker!!
if (!usingANSI) {
@ -122,7 +125,15 @@ object ANSIProgressRenderer {
// Put a blank line between any logging and us.
ansi.eraseLine()
ansi.newline()
val newLinesDrawn = 1 + pt.renderLevel(ansi, 0, pt.allSteps)
var newLinesDrawn = 1 + pt.renderLevel(ansi, 0, error != null)
if (error != null) {
ansi.a("$SKULL_AND_CROSSBONES $error")
ansi.eraseLine(Ansi.Erase.FORWARD)
ansi.newline()
newLinesDrawn++
}
if (newLinesDrawn < prevLinesDrawn) {
// If some steps were removed from the progress tracker, we don't want to leave junk hanging around below.
val linesToClear = prevLinesDrawn - newLinesDrawn
@ -140,7 +151,7 @@ object ANSIProgressRenderer {
}
// Returns number of lines rendered.
private fun ProgressTracker.renderLevel(ansi: Ansi, indent: Int, allSteps: List<Pair<Int, ProgressTracker.Step>>): Int {
private fun ProgressTracker.renderLevel(ansi: Ansi, indent: Int, error: Boolean): Int {
with(ansi) {
var lines = 0
for ((index, step) in steps.withIndex()) {
@ -149,10 +160,11 @@ object ANSIProgressRenderer {
if (indent > 0 && step == ProgressTracker.DONE) continue
val marker = when {
index < stepIndex -> Emoji.CODE_GREEN_TICK + " "
index == stepIndex && step == ProgressTracker.DONE -> Emoji.CODE_GREEN_TICK + " "
index == stepIndex -> Emoji.CODE_RIGHT_ARROW + " "
else -> " "
index < stepIndex -> "$CODE_GREEN_TICK "
index == stepIndex && step == ProgressTracker.DONE -> "$CODE_GREEN_TICK "
index == stepIndex -> "$CODE_RIGHT_ARROW "
error -> "$CODE_NO_ENTRY "
else -> " "
}
a(" ".repeat(indent))
a(marker)
@ -168,7 +180,7 @@ object ANSIProgressRenderer {
val child = getChildProgressTracker(step)
if (child != null)
lines += child.renderLevel(ansi, indent + 1, allSteps)
lines += child.renderLevel(ansi, indent + 1, error)
}
return lines
}

View File

@ -3,6 +3,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.*
import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.DummyState
import net.corda.core.contracts.issuedBy
@ -10,17 +11,16 @@ import net.corda.core.crypto.Party
import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.getOrThrow
import net.corda.core.map
import net.corda.core.messaging.MessageRecipients
import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.ServiceInfo
import net.corda.core.random63BitValue
import net.corda.core.serialization.OpaqueBytes
import net.corda.core.serialization.deserialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.LogHelper
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.ProgressTracker.Change
import net.corda.core.utilities.unwrap
import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow
@ -44,6 +44,7 @@ import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import rx.Notification
import rx.Observable
import java.util.*
import kotlin.reflect.KClass
@ -379,18 +380,36 @@ class StateMachineManagerTests {
net.runNetwork()
assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy {
resultFuture.getOrThrow()
}.withMessageContaining(String::class.java.name)
}.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive
}
@Test
fun `non-FlowException thrown on other side`() {
node2.services.registerFlowInitiator(ReceiveFlow::class) { ExceptionFlow { Exception("evil bug!") } }
val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture
net.runNetwork()
val exceptionResult = assertFailsWith(FlowSessionException::class) {
resultFuture.getOrThrow()
val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) {
ExceptionFlow { Exception("evil bug!") }
}
assertThat(exceptionResult.message).doesNotContain("evil bug!")
val erroringFlowSteps = erroringFlowFuture.flatMap { it.progressSteps }
val receiveFlow = ReceiveFlow(node2.info.legalIdentity)
val receiveFlowSteps = receiveFlow.progressSteps
val receiveFlowResult = node1.services.startFlow(receiveFlow).resultFuture
net.runNetwork()
assertThat(erroringFlowSteps.get()).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlowFuture.get().exceptionThrown)
)
val receiveFlowException = assertFailsWith(FlowSessionException::class) {
receiveFlowResult.getOrThrow()
}
assertThat(receiveFlowException.message).doesNotContain("evil bug!")
assertThat(receiveFlowSteps.get()).containsExactly(
Notification.createOnNext(ReceiveFlow.START_STEP),
Notification.createOnError(receiveFlowException)
)
assertSessionTransfers(
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
@ -400,11 +419,15 @@ class StateMachineManagerTests {
@Test
fun `FlowException thrown on other side`() {
val erroringFlowFuture = node2.initiateSingleShotFlow(ReceiveFlow::class) {
val erroringFlow = node2.initiateSingleShotFlow(ReceiveFlow::class) {
ExceptionFlow { MyFlowException("Nothing useful") }
}
val erroringFlowSteps = erroringFlow.flatMap { it.progressSteps }
val receivingFiber = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)) as FlowStateMachineImpl
net.runNetwork()
assertThatExceptionOfType(MyFlowException::class.java)
.isThrownBy { receivingFiber.resultFuture.getOrThrow() }
.withMessage("Nothing useful")
@ -412,13 +435,18 @@ class StateMachineManagerTests {
databaseTransaction(node2.database) {
assertThat(node2.checkpointStorage.checkpoints()).isEmpty()
}
val errorFlow = erroringFlowFuture.getOrThrow()
assertThat(receivingFiber.isTerminated).isTrue()
assertThat((errorFlow.stateMachine as FlowStateMachineImpl).isTerminated).isTrue()
assertThat((erroringFlow.get().stateMachine as FlowStateMachineImpl).isTerminated).isTrue()
assertThat(erroringFlowSteps.get()).containsExactly(
Notification.createOnNext(ExceptionFlow.START_STEP),
Notification.createOnError(erroringFlow.get().exceptionThrown)
)
assertSessionTransfers(
node1 sent sessionInit(ReceiveFlow::class) to node2,
node2 sent sessionConfirm to node1,
node2 sent erroredEnd(errorFlow.exceptionThrown) to node1
node2 sent erroredEnd(erroringFlow.get().exceptionThrown) to node1
)
// Make sure the original stack trace isn't sent down the wire
assertThat((sessionTransfers.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty()
@ -606,6 +634,15 @@ class StateMachineManagerTests {
private infix fun MockNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(id, message)
private infix fun Pair<Int, SessionMessage>.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.net.myAddress)
private val FlowLogic<*>.progressSteps: ListenableFuture<List<Notification<ProgressTracker.Step>>> get() {
return progressTracker!!.changes
.ofType(Change.Position::class.java)
.map { it.newStep }
.materialize()
.toList()
.toFuture()
}
private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() {
@Transient var flowStarted = false
@ -630,17 +667,22 @@ class StateMachineManagerTests {
private class ReceiveFlow(vararg val otherParties: Party) : FlowLogic<Unit>() {
private var nonTerminating: Boolean = false
object START_STEP : ProgressTracker.Step("Starting")
object RECEIVED_STEP : ProgressTracker.Step("Received")
init {
require(otherParties.isNotEmpty())
}
override val progressTracker: ProgressTracker = ProgressTracker(START_STEP, RECEIVED_STEP)
private var nonTerminating: Boolean = false
@Transient var receivedPayloads: List<String> = emptyList()
@Suspendable
override fun call() {
progressTracker.currentStep = START_STEP
receivedPayloads = otherParties.map { receive<String>(it).unwrap { it } }
progressTracker.currentStep = RECEIVED_STEP
if (nonTerminating) {
Fiber.park()
}
@ -664,8 +706,13 @@ class StateMachineManagerTests {
}
private class ExceptionFlow<E : Exception>(val exception: () -> E) : FlowLogic<Nothing>() {
object START_STEP : ProgressTracker.Step("Starting")
override val progressTracker: ProgressTracker = ProgressTracker(START_STEP)
lateinit var exceptionThrown: E
override fun call(): Nothing {
progressTracker.currentStep = START_STEP
exceptionThrown = exception()
throw exceptionThrown
}