Protocols can use the serviceHub lazily in their constructors

This commit is contained in:
Shams Asari
2016-06-16 16:00:13 +01:00
parent cb1b274d5c
commit 7f3458803c
4 changed files with 45 additions and 28 deletions

View File

@ -31,7 +31,11 @@ abstract class ProtocolLogic<T> {
/** This is where you should log things to. */ /** This is where you should log things to. */
val logger: Logger get() = psm.logger val logger: Logger get() = psm.logger
/** Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts */ /**
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
* only available once the protocol has started, which means it cannnot be accessed in the constructor. Either
* access this lazily or from inside [call].
*/
val serviceHub: ServiceHub get() = psm.serviceHub val serviceHub: ServiceHub get() = psm.serviceHub
// Kotlin helpers that allow the use of generic types. // Kotlin helpers that allow the use of generic types.

View File

@ -21,13 +21,16 @@ import org.slf4j.LoggerFactory
* a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost * a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost
* logic element gets to return the value that the entire state machine resolves to. * logic element gets to return the value that the entire state machine resolves to.
*/ */
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> { class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
scheduler: FiberScheduler,
private val loggerName: String)
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
// These fields shouldn't be serialised, so they are marked @Transient. // These fields shouldn't be serialised, so they are marked @Transient.
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null
@Transient private var receivedPayload: Any? = null
@Transient lateinit override var serviceHub: ServiceHubInternal @Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null
@Transient private var _logger: Logger? = null @Transient private var _logger: Logger? = null
override val logger: Logger get() { override val logger: Logger get() {
@ -52,14 +55,6 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
logic.psm = this logic.psm = this
} }
fun prepareForResumeWith(serviceHub: ServiceHubInternal,
receivedPayload: Any?,
suspendAction: (StateMachineManager.FiberRequest) -> Unit) {
this.serviceHub = serviceHub
this.receivedPayload = receivedPayload
this.suspendAction = suspendAction
}
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable @Suppress("UNCHECKED_CAST")
override fun run(): R { override fun run(): R {
val result = try { val result = try {

View File

@ -6,16 +6,13 @@ import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import com.codahale.metrics.Gauge import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.abbreviate import com.r3corda.core.abbreviate
import com.r3corda.core.messaging.MessageRecipients import com.r3corda.core.messaging.MessageRecipients
import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.* import com.r3corda.core.serialization.*
import com.r3corda.core.then
import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.trace import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint import com.r3corda.node.services.api.Checkpoint
@ -115,12 +112,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
} }
fun start() { fun start() {
checkpointStorage.checkpoints.forEach { restoreCheckpoint(it) } checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) }
} }
private fun restoreCheckpoint(checkpoint: Checkpoint) { private fun restoreFromCheckpoint(checkpoint: Checkpoint) {
val fiber = deserializeFiber(checkpoint.serialisedFiber) val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, checkpoint) initFiber(fiber, { checkpoint })
val topic = checkpoint.awaitingTopic val topic = checkpoint.awaitingTopic
if (topic != null) { if (topic != null) {
@ -177,8 +174,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
} }
} }
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) { private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) {
stateMachines[psm] = checkpoint psm.serviceHub = serviceHub
psm.suspendAction = { request ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
onNextSuspend(psm, request)
}
psm.actionOnEnd = { psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
val finalCheckpoint = stateMachines.remove(psm) val finalCheckpoint = stateMachines.remove(psm)
@ -188,6 +189,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
totalFinishedProtocols.inc() totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE) notifyChangeObservers(psm, AddOrRemove.REMOVE)
} }
stateMachines[psm] = startingCheckpoint()
notifyChangeObservers(psm, AddOrRemove.ADD) notifyChangeObservers(psm, AddOrRemove.ADD)
} }
@ -199,10 +201,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> { fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
try { try {
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion
initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null) val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.addCheckpoint(checkpoint)
// Need to add before iterating in case of immediate completion checkpoint
initFiber(fiber, checkpoint) }
executor.executeASAP { executor.executeASAP {
iterateStateMachine(fiber, null) { iterateStateMachine(fiber, null) {
fiber.start() fiber.start()
@ -234,10 +238,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
receivedPayload: Any?, receivedPayload: Any?,
resumeAction: (Any?) -> Unit) { resumeAction: (Any?) -> Unit) {
executor.checkOnThread() executor.checkOnThread()
psm.prepareForResumeWith(serviceHub, receivedPayload) { request -> psm.receivedPayload = receivedPayload
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
onNextSuspend(psm, request)
}
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" } psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
resumeAction(receivedPayload) resumeAction(receivedPayload)
} }

View File

@ -27,7 +27,7 @@ class StateMachineManagerTests {
@Test @Test
fun `newly added protocol is preserved on restart`() { fun `newly added protocol is preserved on restart`() {
smm.add("mock", ProtocolWithoutCheckpoints()) smm.add("test", ProtocolWithoutCheckpoints())
// Ensure we're restoring from the original add checkpoint // Ensure we're restoring from the original add checkpoint
assertThat(checkpointStorage.allCheckpoints).hasSize(1) assertThat(checkpointStorage.allCheckpoints).hasSize(1)
val restoredProtocol = createManager().run { val restoredProtocol = createManager().run {
@ -37,11 +37,19 @@ class StateMachineManagerTests {
assertThat(restoredProtocol.protocolStarted).isTrue() assertThat(restoredProtocol.protocolStarted).isTrue()
} }
@Test
fun `protocol can lazily use the serviceHub in its constructor`() {
val protocol = ProtocolWithLazyServiceHub()
smm.add("test", protocol)
assertThat(protocol.lazyTime).isNotNull()
}
private fun createManager() = StateMachineManager(object : MockServices() { private fun createManager() = StateMachineManager(object : MockServices() {
override val networkService: MessagingService get() = network override val networkService: MessagingService get() = network
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD) }, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() { private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
@Transient var protocolStarted = false @Transient var protocolStarted = false
@ -54,6 +62,15 @@ class StateMachineManagerTests {
} }
private class ProtocolWithLazyServiceHub : ProtocolLogic<Unit>() {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() {}
}
class RecordingCheckpointStorage : CheckpointStorage { class RecordingCheckpointStorage : CheckpointStorage {
private val _checkpoints = ArrayList<Checkpoint>() private val _checkpoints = ArrayList<Checkpoint>()