mirror of
https://github.com/corda/corda.git
synced 2025-04-07 19:34:41 +00:00
Initial checkpoint when protocol is first added
This commit is contained in:
parent
eb4c24abcb
commit
860353c4d4
@ -1,10 +1,11 @@
|
||||
package com.r3corda.contracts.testing
|
||||
|
||||
import com.r3corda.contracts.*
|
||||
import com.r3corda.contracts.cash.Cash
|
||||
import com.r3corda.contracts.cash.CASH_PROGRAM_ID
|
||||
import com.r3corda.contracts.cash.Cash
|
||||
import com.r3corda.core.contracts.Amount
|
||||
import com.r3corda.core.contracts.Contract
|
||||
import com.r3corda.core.contracts.DUMMY_PROGRAM_ID
|
||||
import com.r3corda.core.contracts.DummyContract
|
||||
import com.r3corda.core.crypto.NullPublicKey
|
||||
import com.r3corda.core.crypto.Party
|
||||
|
@ -24,7 +24,7 @@ import org.slf4j.LoggerFactory
|
||||
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
||||
|
||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null
|
||||
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null
|
||||
@Transient private var receivedPayload: Any? = null
|
||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||
@ -54,7 +54,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
|
||||
|
||||
fun prepareForResumeWith(serviceHub: ServiceHubInternal,
|
||||
receivedPayload: Any?,
|
||||
suspendAction: (StateMachineManager.FiberRequest, ProtocolStateMachineImpl<*>) -> Unit) {
|
||||
suspendAction: (StateMachineManager.FiberRequest) -> Unit) {
|
||||
this.serviceHub = serviceHub
|
||||
this.receivedPayload = receivedPayload
|
||||
this.suspendAction = suspendAction
|
||||
@ -108,7 +108,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
|
||||
private fun suspend(with: StateMachineManager.FiberRequest) {
|
||||
parkAndSerialize { fiber, serializer ->
|
||||
try {
|
||||
suspendAction!!(with, this)
|
||||
suspendAction!!(with)
|
||||
} catch (t: Throwable) {
|
||||
logger.warn("Captured exception which was swallowed by Quasar", t)
|
||||
// TODO to throw or not to throw, that is the question
|
||||
|
@ -177,7 +177,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
}
|
||||
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) {
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
|
||||
stateMachines[psm] = checkpoint
|
||||
psm.actionOnEnd = {
|
||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
@ -199,9 +199,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
||||
try {
|
||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
// Need to add before iterating in case of immediate completion
|
||||
// TODO: create an initial checkpoint here
|
||||
initFiber(fiber, null)
|
||||
initFiber(fiber, checkpoint)
|
||||
executor.executeASAP {
|
||||
iterateStateMachine(fiber, null) {
|
||||
fiber.start()
|
||||
@ -233,21 +234,19 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
receivedPayload: Any?,
|
||||
resumeAction: (Any?) -> Unit) {
|
||||
executor.checkOnThread()
|
||||
psm.prepareForResumeWith(serviceHub, receivedPayload) { request, serialisedFiber ->
|
||||
psm.prepareForResumeWith(serviceHub, receivedPayload) { request ->
|
||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||
onNextSuspend(psm, request, serialisedFiber)
|
||||
onNextSuspend(psm, request)
|
||||
}
|
||||
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
||||
resumeAction(receivedPayload)
|
||||
}
|
||||
|
||||
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>,
|
||||
request: FiberRequest,
|
||||
fiber: ProtocolStateMachineImpl<*>) {
|
||||
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) {
|
||||
// We have a request to do something: send, receive, or send-and-receive.
|
||||
if (request is FiberRequest.ExpectingResponse<*>) {
|
||||
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
||||
checkpointOnExpectingResponse(psm, request, serializeFiber(fiber))
|
||||
checkpointOnExpectingResponse(psm, request)
|
||||
}
|
||||
// If a non-null payload to send was provided, send it now.
|
||||
request.payload?.let {
|
||||
@ -267,11 +266,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
}
|
||||
|
||||
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>,
|
||||
request: FiberRequest.ExpectingResponse<*>,
|
||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
|
||||
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) {
|
||||
executor.checkOnThread()
|
||||
val topic = "${request.topic}.${request.sessionIDForReceive}"
|
||||
val serialisedFiber = serializeFiber(psm)
|
||||
updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null)
|
||||
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" }
|
||||
iterateOnResponse(psm, request.responseType, serialisedFiber, topic) {
|
||||
|
@ -14,7 +14,7 @@ import com.r3corda.node.services.persistence.DataVendingService
|
||||
import com.r3corda.node.services.wallet.NodeWalletService
|
||||
import java.time.Clock
|
||||
|
||||
class MockServices(
|
||||
open class MockServices(
|
||||
customWallet: WalletService? = null,
|
||||
val keyManagement: KeyManagementService? = null,
|
||||
val net: MessagingService? = null,
|
||||
|
@ -0,0 +1,74 @@
|
||||
package com.r3corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.r3corda.core.messaging.MessagingService
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.node.services.MockServices
|
||||
import com.r3corda.node.services.api.Checkpoint
|
||||
import com.r3corda.node.services.api.CheckpointStorage
|
||||
import com.r3corda.node.services.network.InMemoryMessagingNetwork
|
||||
import com.r3corda.node.utilities.AffinityExecutor
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.After
|
||||
import org.junit.Test
|
||||
import java.util.*
|
||||
|
||||
class StateMachineManagerTests {
|
||||
|
||||
val checkpointStorage = RecordingCheckpointStorage()
|
||||
val network = InMemoryMessagingNetwork().InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock"))
|
||||
val smm = createManager()
|
||||
|
||||
@After
|
||||
fun cleanUp() {
|
||||
network.stop()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `newly added protocol is preserved on restart`() {
|
||||
smm.add("mock", ProtocolWithoutCheckpoints())
|
||||
// Ensure we're restoring from the original add checkpoint
|
||||
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
|
||||
val restoredProtocol = createManager().run {
|
||||
start()
|
||||
findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first
|
||||
}
|
||||
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||
}
|
||||
|
||||
private fun createManager() = StateMachineManager(object : MockServices() {
|
||||
override val networkService: MessagingService get() = network
|
||||
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
|
||||
|
||||
|
||||
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
|
||||
|
||||
@Transient var protocolStarted = false
|
||||
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
protocolStarted = true
|
||||
Fiber.park()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class RecordingCheckpointStorage : CheckpointStorage {
|
||||
|
||||
private val _checkpoints = ArrayList<Checkpoint>()
|
||||
val allCheckpoints = ArrayList<Checkpoint>()
|
||||
|
||||
override fun addCheckpoint(checkpoint: Checkpoint) {
|
||||
_checkpoints.add(checkpoint)
|
||||
allCheckpoints.add(checkpoint)
|
||||
}
|
||||
|
||||
override fun removeCheckpoint(checkpoint: Checkpoint) {
|
||||
_checkpoints.remove(checkpoint)
|
||||
}
|
||||
|
||||
override val checkpoints: Iterable<Checkpoint> get() = _checkpoints
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user