mirror of
https://github.com/corda/corda.git
synced 2025-06-19 07:38:22 +00:00
Protocols can use the serviceHub lazily in their constructors
This commit is contained in:
@ -31,7 +31,11 @@ abstract class ProtocolLogic<T> {
|
|||||||
/** This is where you should log things to. */
|
/** This is where you should log things to. */
|
||||||
val logger: Logger get() = psm.logger
|
val logger: Logger get() = psm.logger
|
||||||
|
|
||||||
/** Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts */
|
/**
|
||||||
|
* Provides access to big, heavy classes that may be reconstructed from time to time, e.g. across restarts. It is
|
||||||
|
* only available once the protocol has started, which means it cannnot be accessed in the constructor. Either
|
||||||
|
* access this lazily or from inside [call].
|
||||||
|
*/
|
||||||
val serviceHub: ServiceHub get() = psm.serviceHub
|
val serviceHub: ServiceHub get() = psm.serviceHub
|
||||||
|
|
||||||
// Kotlin helpers that allow the use of generic types.
|
// Kotlin helpers that allow the use of generic types.
|
||||||
|
@ -21,13 +21,16 @@ import org.slf4j.LoggerFactory
|
|||||||
* a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost
|
* a protocol invokes a sub-protocol, then it will pass along the PSM to the child. The call method of the topmost
|
||||||
* logic element gets to return the value that the entire state machine resolves to.
|
* logic element gets to return the value that the entire state machine resolves to.
|
||||||
*/
|
*/
|
||||||
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
||||||
|
scheduler: FiberScheduler,
|
||||||
|
private val loggerName: String)
|
||||||
|
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
||||||
|
|
||||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||||
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null
|
|
||||||
@Transient private var receivedPayload: Any? = null
|
|
||||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||||
|
@Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit
|
||||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||||
|
@Transient internal var receivedPayload: Any? = null
|
||||||
|
|
||||||
@Transient private var _logger: Logger? = null
|
@Transient private var _logger: Logger? = null
|
||||||
override val logger: Logger get() {
|
override val logger: Logger get() {
|
||||||
@ -52,14 +55,6 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
|
|||||||
logic.psm = this
|
logic.psm = this
|
||||||
}
|
}
|
||||||
|
|
||||||
fun prepareForResumeWith(serviceHub: ServiceHubInternal,
|
|
||||||
receivedPayload: Any?,
|
|
||||||
suspendAction: (StateMachineManager.FiberRequest) -> Unit) {
|
|
||||||
this.serviceHub = serviceHub
|
|
||||||
this.receivedPayload = receivedPayload
|
|
||||||
this.suspendAction = suspendAction
|
|
||||||
}
|
|
||||||
|
|
||||||
@Suspendable @Suppress("UNCHECKED_CAST")
|
@Suspendable @Suppress("UNCHECKED_CAST")
|
||||||
override fun run(): R {
|
override fun run(): R {
|
||||||
val result = try {
|
val result = try {
|
||||||
|
@ -6,16 +6,13 @@ import co.paralleluniverse.io.serialization.kryo.KryoSerializer
|
|||||||
import com.codahale.metrics.Gauge
|
import com.codahale.metrics.Gauge
|
||||||
import com.esotericsoftware.kryo.Kryo
|
import com.esotericsoftware.kryo.Kryo
|
||||||
import com.google.common.base.Throwables
|
import com.google.common.base.Throwables
|
||||||
import com.google.common.util.concurrent.Futures
|
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.google.common.util.concurrent.SettableFuture
|
|
||||||
import com.r3corda.core.abbreviate
|
import com.r3corda.core.abbreviate
|
||||||
import com.r3corda.core.messaging.MessageRecipients
|
import com.r3corda.core.messaging.MessageRecipients
|
||||||
import com.r3corda.core.messaging.runOnNextMessage
|
import com.r3corda.core.messaging.runOnNextMessage
|
||||||
import com.r3corda.core.messaging.send
|
import com.r3corda.core.messaging.send
|
||||||
import com.r3corda.core.protocols.ProtocolLogic
|
import com.r3corda.core.protocols.ProtocolLogic
|
||||||
import com.r3corda.core.serialization.*
|
import com.r3corda.core.serialization.*
|
||||||
import com.r3corda.core.then
|
|
||||||
import com.r3corda.core.utilities.ProgressTracker
|
import com.r3corda.core.utilities.ProgressTracker
|
||||||
import com.r3corda.core.utilities.trace
|
import com.r3corda.core.utilities.trace
|
||||||
import com.r3corda.node.services.api.Checkpoint
|
import com.r3corda.node.services.api.Checkpoint
|
||||||
@ -115,12 +112,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
}
|
}
|
||||||
|
|
||||||
fun start() {
|
fun start() {
|
||||||
checkpointStorage.checkpoints.forEach { restoreCheckpoint(it) }
|
checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) }
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun restoreCheckpoint(checkpoint: Checkpoint) {
|
private fun restoreFromCheckpoint(checkpoint: Checkpoint) {
|
||||||
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
||||||
initFiber(fiber, checkpoint)
|
initFiber(fiber, { checkpoint })
|
||||||
|
|
||||||
val topic = checkpoint.awaitingTopic
|
val topic = checkpoint.awaitingTopic
|
||||||
if (topic != null) {
|
if (topic != null) {
|
||||||
@ -177,8 +174,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
|
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) {
|
||||||
stateMachines[psm] = checkpoint
|
psm.serviceHub = serviceHub
|
||||||
|
psm.suspendAction = { request ->
|
||||||
|
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||||
|
onNextSuspend(psm, request)
|
||||||
|
}
|
||||||
psm.actionOnEnd = {
|
psm.actionOnEnd = {
|
||||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||||
val finalCheckpoint = stateMachines.remove(psm)
|
val finalCheckpoint = stateMachines.remove(psm)
|
||||||
@ -188,6 +189,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
totalFinishedProtocols.inc()
|
totalFinishedProtocols.inc()
|
||||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||||
}
|
}
|
||||||
|
stateMachines[psm] = startingCheckpoint()
|
||||||
notifyChangeObservers(psm, AddOrRemove.ADD)
|
notifyChangeObservers(psm, AddOrRemove.ADD)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -199,10 +201,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
||||||
try {
|
try {
|
||||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
|
||||||
checkpointStorage.addCheckpoint(checkpoint)
|
|
||||||
// Need to add before iterating in case of immediate completion
|
// Need to add before iterating in case of immediate completion
|
||||||
initFiber(fiber, checkpoint)
|
initFiber(fiber) {
|
||||||
|
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
||||||
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
|
checkpoint
|
||||||
|
}
|
||||||
executor.executeASAP {
|
executor.executeASAP {
|
||||||
iterateStateMachine(fiber, null) {
|
iterateStateMachine(fiber, null) {
|
||||||
fiber.start()
|
fiber.start()
|
||||||
@ -234,10 +238,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
receivedPayload: Any?,
|
receivedPayload: Any?,
|
||||||
resumeAction: (Any?) -> Unit) {
|
resumeAction: (Any?) -> Unit) {
|
||||||
executor.checkOnThread()
|
executor.checkOnThread()
|
||||||
psm.prepareForResumeWith(serviceHub, receivedPayload) { request ->
|
psm.receivedPayload = receivedPayload
|
||||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
|
||||||
onNextSuspend(psm, request)
|
|
||||||
}
|
|
||||||
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
||||||
resumeAction(receivedPayload)
|
resumeAction(receivedPayload)
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `newly added protocol is preserved on restart`() {
|
fun `newly added protocol is preserved on restart`() {
|
||||||
smm.add("mock", ProtocolWithoutCheckpoints())
|
smm.add("test", ProtocolWithoutCheckpoints())
|
||||||
// Ensure we're restoring from the original add checkpoint
|
// Ensure we're restoring from the original add checkpoint
|
||||||
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
|
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
|
||||||
val restoredProtocol = createManager().run {
|
val restoredProtocol = createManager().run {
|
||||||
@ -37,11 +37,19 @@ class StateMachineManagerTests {
|
|||||||
assertThat(restoredProtocol.protocolStarted).isTrue()
|
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `protocol can lazily use the serviceHub in its constructor`() {
|
||||||
|
val protocol = ProtocolWithLazyServiceHub()
|
||||||
|
smm.add("test", protocol)
|
||||||
|
assertThat(protocol.lazyTime).isNotNull()
|
||||||
|
}
|
||||||
|
|
||||||
private fun createManager() = StateMachineManager(object : MockServices() {
|
private fun createManager() = StateMachineManager(object : MockServices() {
|
||||||
override val networkService: MessagingService get() = network
|
override val networkService: MessagingService get() = network
|
||||||
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
|
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
|
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
|
||||||
|
|
||||||
@Transient var protocolStarted = false
|
@Transient var protocolStarted = false
|
||||||
@ -54,6 +62,15 @@ class StateMachineManagerTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private class ProtocolWithLazyServiceHub : ProtocolLogic<Unit>() {
|
||||||
|
|
||||||
|
val lazyTime by lazy { serviceHub.clock.instant() }
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
override fun call() {}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class RecordingCheckpointStorage : CheckpointStorage {
|
class RecordingCheckpointStorage : CheckpointStorage {
|
||||||
|
|
||||||
private val _checkpoints = ArrayList<Checkpoint>()
|
private val _checkpoints = ArrayList<Checkpoint>()
|
||||||
|
Reference in New Issue
Block a user