From 7f3458803cc12b62b97c652f69f41a1c30ff854f Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Thu, 16 Jun 2016 16:00:13 +0100 Subject: [PATCH] Protocols can use the serviceHub lazily in their constructors --- .../r3corda/core/protocols/ProtocolLogic.kt | 6 +++- .../statemachine/ProtocolStateMachineImpl.kt | 17 ++++------ .../statemachine/StateMachineManager.kt | 31 ++++++++++--------- .../statemachine/StateMachineManagerTests.kt | 19 +++++++++++- 4 files changed, 45 insertions(+), 28 deletions(-) diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt index b28c857cde..d602e040ef 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogic.kt @@ -31,7 +31,11 @@ abstract class ProtocolLogic { /** This is where you should log things to. */ val logger: Logger get() = psm.logger - /** Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts */ + /** + * Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is + * only available once the protocol has started, which means it cannnot be accessed in the constructor. Either + * access this lazily or from inside [call]. + */ val serviceHub: ServiceHub get() = psm.serviceHub // Kotlin helpers that allow the use of generic types. diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index 7c8881ef75..a6850e0f5c 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -21,13 +21,16 @@ import org.slf4j.LoggerFactory * a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost * logic element gets to return the value that the entire state machine resolves to. */ -class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberScheduler, private val loggerName: String) : Fiber("protocol", scheduler), ProtocolStateMachine { +class ProtocolStateMachineImpl(val logic: ProtocolLogic, + scheduler: FiberScheduler, + private val loggerName: String) +: Fiber("protocol", scheduler), ProtocolStateMachine { // These fields shouldn't be serialised, so they are marked @Transient. - @Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null - @Transient private var receivedPayload: Any? = null @Transient lateinit override var serviceHub: ServiceHubInternal + @Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit + @Transient internal var receivedPayload: Any? = null @Transient private var _logger: Logger? = null override val logger: Logger get() { @@ -52,14 +55,6 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberS logic.psm = this } - fun prepareForResumeWith(serviceHub: ServiceHubInternal, - receivedPayload: Any?, - suspendAction: (StateMachineManager.FiberRequest) -> Unit) { - this.serviceHub = serviceHub - this.receivedPayload = receivedPayload - this.suspendAction = suspendAction - } - @Suspendable @Suppress("UNCHECKED_CAST") override fun run(): R { val result = try { diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index ca40b18565..2616585dec 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -6,16 +6,13 @@ import co.paralleluniverse.io.serialization.kryo.KryoSerializer import com.codahale.metrics.Gauge import com.esotericsoftware.kryo.Kryo import com.google.common.base.Throwables -import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.abbreviate import com.r3corda.core.messaging.MessageRecipients import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.send import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.serialization.* -import com.r3corda.core.then import com.r3corda.core.utilities.ProgressTracker import com.r3corda.core.utilities.trace import com.r3corda.node.services.api.Checkpoint @@ -115,12 +112,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } fun start() { - checkpointStorage.checkpoints.forEach { restoreCheckpoint(it) } + checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) } } - private fun restoreCheckpoint(checkpoint: Checkpoint) { + private fun restoreFromCheckpoint(checkpoint: Checkpoint) { val fiber = deserializeFiber(checkpoint.serialisedFiber) - initFiber(fiber, checkpoint) + initFiber(fiber, { checkpoint }) val topic = checkpoint.awaitingTopic if (topic != null) { @@ -177,8 +174,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } } - private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) { - stateMachines[psm] = checkpoint + private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) { + psm.serviceHub = serviceHub + psm.suspendAction = { request -> + psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" } + onNextSuspend(psm, request) + } psm.actionOnEnd = { psm.logic.progressTracker?.currentStep = ProgressTracker.DONE val finalCheckpoint = stateMachines.remove(psm) @@ -188,6 +189,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService totalFinishedProtocols.inc() notifyChangeObservers(psm, AddOrRemove.REMOVE) } + stateMachines[psm] = startingCheckpoint() notifyChangeObservers(psm, AddOrRemove.ADD) } @@ -199,10 +201,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService fun add(loggerName: String, logic: ProtocolLogic): ListenableFuture { try { val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) - val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null) - checkpointStorage.addCheckpoint(checkpoint) // Need to add before iterating in case of immediate completion - initFiber(fiber, checkpoint) + initFiber(fiber) { + val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null) + checkpointStorage.addCheckpoint(checkpoint) + checkpoint + } executor.executeASAP { iterateStateMachine(fiber, null) { fiber.start() @@ -234,10 +238,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService receivedPayload: Any?, resumeAction: (Any?) -> Unit) { executor.checkOnThread() - psm.prepareForResumeWith(serviceHub, receivedPayload) { request -> - psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" } - onNextSuspend(psm, request) - } + psm.receivedPayload = receivedPayload psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" } resumeAction(receivedPayload) } diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index eb6552d791..de1e288d11 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -27,7 +27,7 @@ class StateMachineManagerTests { @Test fun `newly added protocol is preserved on restart`() { - smm.add("mock", ProtocolWithoutCheckpoints()) + smm.add("test", ProtocolWithoutCheckpoints()) // Ensure we're restoring from the original add checkpoint assertThat(checkpointStorage.allCheckpoints).hasSize(1) val restoredProtocol = createManager().run { @@ -37,11 +37,19 @@ class StateMachineManagerTests { assertThat(restoredProtocol.protocolStarted).isTrue() } + @Test + fun `protocol can lazily use the serviceHub in its constructor`() { + val protocol = ProtocolWithLazyServiceHub() + smm.add("test", protocol) + assertThat(protocol.lazyTime).isNotNull() + } + private fun createManager() = StateMachineManager(object : MockServices() { override val networkService: MessagingService get() = network }, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD) + private class ProtocolWithoutCheckpoints : ProtocolLogic() { @Transient var protocolStarted = false @@ -54,6 +62,15 @@ class StateMachineManagerTests { } + private class ProtocolWithLazyServiceHub : ProtocolLogic() { + + val lazyTime by lazy { serviceHub.clock.instant() } + + @Suspendable + override fun call() {} + } + + class RecordingCheckpointStorage : CheckpointStorage { private val _checkpoints = ArrayList()