Simplify last checkpoint removal race condition fix

This commit is contained in:
Shams Asari
2016-06-14 16:10:29 +01:00
parent aa153be6f0
commit 853bc683f8
2 changed files with 16 additions and 19 deletions

View File

@ -27,6 +27,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null @Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null
@Transient private var receivedPayload: Any? = null @Transient private var receivedPayload: Any? = null
@Transient lateinit override var serviceHub: ServiceHubInternal @Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var actionOnEnd: () -> Unit
@Transient private var _logger: Logger? = null @Transient private var _logger: Logger? = null
override val logger: Logger get() { override val logger: Logger get() {
@ -61,15 +62,18 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable @Suppress("UNCHECKED_CAST")
override fun run(): R { override fun run(): R {
try { val result = try {
val result = logic.call() logic.call()
if (result != null) } catch (t: Throwable) {
actionOnEnd()
_resultFuture?.setException(t)
throw t
}
// This is to prevent actionOnEnd being called twice if it throws an exception
actionOnEnd()
_resultFuture?.set(result) _resultFuture?.set(result)
return result return result
} catch (e: Throwable) {
_resultFuture?.setException(e)
throw e
}
} }
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable @Suppress("UNCHECKED_CAST")

View File

@ -64,9 +64,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
// property. // property.
private val stateMachines = synchronizedMap(LinkedHashMap<ProtocolStateMachineImpl<*>, Checkpoint>()) private val stateMachines = synchronizedMap(LinkedHashMap<ProtocolStateMachineImpl<*>, Checkpoint>())
// A map from fibers to futures that will be completed when the last corresponding checkpoint is removed
private val finalCheckpointRemovedFutures = synchronizedMap(HashMap<ProtocolStateMachineImpl<*>, SettableFuture<Unit>>())
// Monitoring support. // Monitoring support.
private val metrics = serviceHub.monitoringService.metrics private val metrics = serviceHub.monitoringService.metrics
@ -182,19 +179,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) { private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) {
stateMachines[psm] = checkpoint stateMachines[psm] = checkpoint
notifyChangeObservers(psm, AddOrRemove.ADD) psm.actionOnEnd = {
val finalCheckpointRemovedFuture: SettableFuture<Unit> = SettableFuture.create()
finalCheckpointRemovedFutures[psm] = finalCheckpointRemovedFuture
psm.resultFuture.then(executor) {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
val finalCheckpoint = stateMachines.remove(psm) val finalCheckpoint = stateMachines.remove(psm)
if (finalCheckpoint != null) { if (finalCheckpoint != null) {
checkpointStorage.removeCheckpoint(finalCheckpoint) checkpointStorage.removeCheckpoint(finalCheckpoint)
} }
finalCheckpointRemovedFuture.set(Unit)
totalFinishedProtocols.inc() totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE) notifyChangeObservers(psm, AddOrRemove.REMOVE)
} }
notifyChangeObservers(psm, AddOrRemove.ADD)
} }
/** /**
@ -214,8 +208,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
} }
totalStartedProtocols.inc() totalStartedProtocols.inc()
} }
val finalCheckpointRemovedFuture = finalCheckpointRemovedFutures.remove(fiber) return fiber.resultFuture
return Futures.transformAsync(finalCheckpointRemovedFuture, { fiber.resultFuture })
} catch (e: Throwable) { } catch (e: Throwable) {
e.printStackTrace() e.printStackTrace()
throw e throw e