mirror of
https://github.com/corda/corda.git
synced 2025-02-20 17:33:15 +00:00
Protocols can use the serviceHub lazily in their constructors
This commit is contained in:
parent
cb1b274d5c
commit
7f3458803c
@ -31,7 +31,11 @@ abstract class ProtocolLogic<T> {
|
||||
/** This is where you should log things to. */
|
||||
val logger: Logger get() = psm.logger
|
||||
|
||||
/** Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts */
|
||||
/**
|
||||
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
||||
* only available once the protocol has started, which means it cannnot be accessed in the constructor. Either
|
||||
* access this lazily or from inside [call].
|
||||
*/
|
||||
val serviceHub: ServiceHub get() = psm.serviceHub
|
||||
|
||||
// Kotlin helpers that allow the use of generic types.
|
||||
|
@ -21,13 +21,16 @@ import org.slf4j.LoggerFactory
|
||||
* a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost
|
||||
* logic element gets to return the value that the entire state machine resolves to.
|
||||
*/
|
||||
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
||||
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
||||
scheduler: FiberScheduler,
|
||||
private val loggerName: String)
|
||||
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
||||
|
||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null
|
||||
@Transient private var receivedPayload: Any? = null
|
||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||
@Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit
|
||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||
@Transient internal var receivedPayload: Any? = null
|
||||
|
||||
@Transient private var _logger: Logger? = null
|
||||
override val logger: Logger get() {
|
||||
@ -52,14 +55,6 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
|
||||
logic.psm = this
|
||||
}
|
||||
|
||||
fun prepareForResumeWith(serviceHub: ServiceHubInternal,
|
||||
receivedPayload: Any?,
|
||||
suspendAction: (StateMachineManager.FiberRequest) -> Unit) {
|
||||
this.serviceHub = serviceHub
|
||||
this.receivedPayload = receivedPayload
|
||||
this.suspendAction = suspendAction
|
||||
}
|
||||
|
||||
@Suspendable @Suppress("UNCHECKED_CAST")
|
||||
override fun run(): R {
|
||||
val result = try {
|
||||
|
@ -6,16 +6,13 @@ import co.paralleluniverse.io.serialization.kryo.KryoSerializer
|
||||
import com.codahale.metrics.Gauge
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.google.common.base.Throwables
|
||||
import com.google.common.util.concurrent.Futures
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import com.r3corda.core.abbreviate
|
||||
import com.r3corda.core.messaging.MessageRecipients
|
||||
import com.r3corda.core.messaging.runOnNextMessage
|
||||
import com.r3corda.core.messaging.send
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.serialization.*
|
||||
import com.r3corda.core.then
|
||||
import com.r3corda.core.utilities.ProgressTracker
|
||||
import com.r3corda.core.utilities.trace
|
||||
import com.r3corda.node.services.api.Checkpoint
|
||||
@ -115,12 +112,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
|
||||
fun start() {
|
||||
checkpointStorage.checkpoints.forEach { restoreCheckpoint(it) }
|
||||
checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) }
|
||||
}
|
||||
|
||||
private fun restoreCheckpoint(checkpoint: Checkpoint) {
|
||||
private fun restoreFromCheckpoint(checkpoint: Checkpoint) {
|
||||
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
||||
initFiber(fiber, checkpoint)
|
||||
initFiber(fiber, { checkpoint })
|
||||
|
||||
val topic = checkpoint.awaitingTopic
|
||||
if (topic != null) {
|
||||
@ -177,8 +174,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
}
|
||||
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
|
||||
stateMachines[psm] = checkpoint
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) {
|
||||
psm.serviceHub = serviceHub
|
||||
psm.suspendAction = { request ->
|
||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||
onNextSuspend(psm, request)
|
||||
}
|
||||
psm.actionOnEnd = {
|
||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
val finalCheckpoint = stateMachines.remove(psm)
|
||||
@ -188,6 +189,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
totalFinishedProtocols.inc()
|
||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||
}
|
||||
stateMachines[psm] = startingCheckpoint()
|
||||
notifyChangeObservers(psm, AddOrRemove.ADD)
|
||||
}
|
||||
|
||||
@ -199,10 +201,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
||||
try {
|
||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
// Need to add before iterating in case of immediate completion
|
||||
initFiber(fiber, checkpoint)
|
||||
initFiber(fiber) {
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
checkpoint
|
||||
}
|
||||
executor.executeASAP {
|
||||
iterateStateMachine(fiber, null) {
|
||||
fiber.start()
|
||||
@ -234,10 +238,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
receivedPayload: Any?,
|
||||
resumeAction: (Any?) -> Unit) {
|
||||
executor.checkOnThread()
|
||||
psm.prepareForResumeWith(serviceHub, receivedPayload) { request ->
|
||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||
onNextSuspend(psm, request)
|
||||
}
|
||||
psm.receivedPayload = receivedPayload
|
||||
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
||||
resumeAction(receivedPayload)
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ class StateMachineManagerTests {
|
||||
|
||||
@Test
|
||||
fun `newly added protocol is preserved on restart`() {
|
||||
smm.add("mock", ProtocolWithoutCheckpoints())
|
||||
smm.add("test", ProtocolWithoutCheckpoints())
|
||||
// Ensure we're restoring from the original add checkpoint
|
||||
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
|
||||
val restoredProtocol = createManager().run {
|
||||
@ -37,11 +37,19 @@ class StateMachineManagerTests {
|
||||
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `protocol can lazily use the serviceHub in its constructor`() {
|
||||
val protocol = ProtocolWithLazyServiceHub()
|
||||
smm.add("test", protocol)
|
||||
assertThat(protocol.lazyTime).isNotNull()
|
||||
}
|
||||
|
||||
private fun createManager() = StateMachineManager(object : MockServices() {
|
||||
override val networkService: MessagingService get() = network
|
||||
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
|
||||
|
||||
|
||||
|
||||
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
|
||||
|
||||
@Transient var protocolStarted = false
|
||||
@ -54,6 +62,15 @@ class StateMachineManagerTests {
|
||||
}
|
||||
|
||||
|
||||
private class ProtocolWithLazyServiceHub : ProtocolLogic<Unit>() {
|
||||
|
||||
val lazyTime by lazy { serviceHub.clock.instant() }
|
||||
|
||||
@Suspendable
|
||||
override fun call() {}
|
||||
}
|
||||
|
||||
|
||||
class RecordingCheckpointStorage : CheckpointStorage {
|
||||
|
||||
private val _checkpoints = ArrayList<Checkpoint>()
|
||||
|
Loading…
x
Reference in New Issue
Block a user