diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index 461c831e66..38cec65a3a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -27,6 +27,7 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberS @Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null @Transient private var receivedPayload: Any? = null @Transient lateinit override var serviceHub: ServiceHubInternal + @Transient internal lateinit var actionOnEnd: () -> Unit @Transient private var _logger: Logger? = null override val logger: Logger get() { @@ -61,15 +62,18 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberS @Suspendable @Suppress("UNCHECKED_CAST") override fun run(): R { - try { - val result = logic.call() - if (result != null) - _resultFuture?.set(result) - return result - } catch (e: Throwable) { - _resultFuture?.setException(e) - throw e + val result = try { + logic.call() + } catch (t: Throwable) { + actionOnEnd() + _resultFuture?.setException(t) + throw t } + + // This is to prevent actionOnEnd being called twice if it throws an exception + actionOnEnd() + _resultFuture?.set(result) + return result } @Suspendable @Suppress("UNCHECKED_CAST") diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index ed2e4f36f6..25a9022089 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -64,9 +64,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService // property. private val stateMachines = synchronizedMap(LinkedHashMap, Checkpoint>()) - // A map from fibers to futures that will be completed when the last corresponding checkpoint is removed - private val finalCheckpointRemovedFutures = synchronizedMap(HashMap, SettableFuture>()) - // Monitoring support. private val metrics = serviceHub.monitoringService.metrics @@ -182,19 +179,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) { stateMachines[psm] = checkpoint - notifyChangeObservers(psm, AddOrRemove.ADD) - val finalCheckpointRemovedFuture: SettableFuture = SettableFuture.create() - finalCheckpointRemovedFutures[psm] = finalCheckpointRemovedFuture - psm.resultFuture.then(executor) { + psm.actionOnEnd = { psm.logic.progressTracker?.currentStep = ProgressTracker.DONE val finalCheckpoint = stateMachines.remove(psm) if (finalCheckpoint != null) { checkpointStorage.removeCheckpoint(finalCheckpoint) } - finalCheckpointRemovedFuture.set(Unit) totalFinishedProtocols.inc() notifyChangeObservers(psm, AddOrRemove.REMOVE) } + notifyChangeObservers(psm, AddOrRemove.ADD) } /** @@ -214,9 +208,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } totalStartedProtocols.inc() } - val finalCheckpointRemovedFuture = finalCheckpointRemovedFutures.remove(fiber) - return Futures.transformAsync(finalCheckpointRemovedFuture, { fiber.resultFuture }) - } catch(e: Throwable) { + return fiber.resultFuture + } catch (e: Throwable) { e.printStackTrace() throw e }