mirror of
https://github.com/corda/corda.git
synced 2025-06-23 01:19:00 +00:00
Initial checkpoint when protocol is first added
This commit is contained in:
@ -1,10 +1,11 @@
|
|||||||
package com.r3corda.contracts.testing
|
package com.r3corda.contracts.testing
|
||||||
|
|
||||||
import com.r3corda.contracts.*
|
import com.r3corda.contracts.*
|
||||||
import com.r3corda.contracts.cash.Cash
|
|
||||||
import com.r3corda.contracts.cash.CASH_PROGRAM_ID
|
import com.r3corda.contracts.cash.CASH_PROGRAM_ID
|
||||||
|
import com.r3corda.contracts.cash.Cash
|
||||||
import com.r3corda.core.contracts.Amount
|
import com.r3corda.core.contracts.Amount
|
||||||
import com.r3corda.core.contracts.Contract
|
import com.r3corda.core.contracts.Contract
|
||||||
|
import com.r3corda.core.contracts.DUMMY_PROGRAM_ID
|
||||||
import com.r3corda.core.contracts.DummyContract
|
import com.r3corda.core.contracts.DummyContract
|
||||||
import com.r3corda.core.crypto.NullPublicKey
|
import com.r3corda.core.crypto.NullPublicKey
|
||||||
import com.r3corda.core.crypto.Party
|
import com.r3corda.core.crypto.Party
|
||||||
|
@ -24,7 +24,7 @@ import org.slf4j.LoggerFactory
|
|||||||
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, private val loggerName: String) : Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
||||||
|
|
||||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||||
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null
|
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null
|
||||||
@Transient private var receivedPayload: Any? = null
|
@Transient private var receivedPayload: Any? = null
|
||||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||||
@ -54,7 +54,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
|
|||||||
|
|
||||||
fun prepareForResumeWith(serviceHub: ServiceHubInternal,
|
fun prepareForResumeWith(serviceHub: ServiceHubInternal,
|
||||||
receivedPayload: Any?,
|
receivedPayload: Any?,
|
||||||
suspendAction: (StateMachineManager.FiberRequest, ProtocolStateMachineImpl<*>) -> Unit) {
|
suspendAction: (StateMachineManager.FiberRequest) -> Unit) {
|
||||||
this.serviceHub = serviceHub
|
this.serviceHub = serviceHub
|
||||||
this.receivedPayload = receivedPayload
|
this.receivedPayload = receivedPayload
|
||||||
this.suspendAction = suspendAction
|
this.suspendAction = suspendAction
|
||||||
@ -108,7 +108,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>, scheduler: FiberS
|
|||||||
private fun suspend(with: StateMachineManager.FiberRequest) {
|
private fun suspend(with: StateMachineManager.FiberRequest) {
|
||||||
parkAndSerialize { fiber, serializer ->
|
parkAndSerialize { fiber, serializer ->
|
||||||
try {
|
try {
|
||||||
suspendAction!!(with, this)
|
suspendAction!!(with)
|
||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
logger.warn("Captured exception which was swallowed by Quasar", t)
|
logger.warn("Captured exception which was swallowed by Quasar", t)
|
||||||
// TODO to throw or not to throw, that is the question
|
// TODO to throw or not to throw, that is the question
|
||||||
|
@ -177,7 +177,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) {
|
private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
|
||||||
stateMachines[psm] = checkpoint
|
stateMachines[psm] = checkpoint
|
||||||
psm.actionOnEnd = {
|
psm.actionOnEnd = {
|
||||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||||
@ -199,9 +199,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
||||||
try {
|
try {
|
||||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||||
|
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
||||||
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
// Need to add before iterating in case of immediate completion
|
// Need to add before iterating in case of immediate completion
|
||||||
// TODO: create an initial checkpoint here
|
initFiber(fiber, checkpoint)
|
||||||
initFiber(fiber, null)
|
|
||||||
executor.executeASAP {
|
executor.executeASAP {
|
||||||
iterateStateMachine(fiber, null) {
|
iterateStateMachine(fiber, null) {
|
||||||
fiber.start()
|
fiber.start()
|
||||||
@ -233,21 +234,19 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
receivedPayload: Any?,
|
receivedPayload: Any?,
|
||||||
resumeAction: (Any?) -> Unit) {
|
resumeAction: (Any?) -> Unit) {
|
||||||
executor.checkOnThread()
|
executor.checkOnThread()
|
||||||
psm.prepareForResumeWith(serviceHub, receivedPayload) { request, serialisedFiber ->
|
psm.prepareForResumeWith(serviceHub, receivedPayload) { request ->
|
||||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||||
onNextSuspend(psm, request, serialisedFiber)
|
onNextSuspend(psm, request)
|
||||||
}
|
}
|
||||||
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
||||||
resumeAction(receivedPayload)
|
resumeAction(receivedPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>,
|
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) {
|
||||||
request: FiberRequest,
|
|
||||||
fiber: ProtocolStateMachineImpl<*>) {
|
|
||||||
// We have a request to do something: send, receive, or send-and-receive.
|
// We have a request to do something: send, receive, or send-and-receive.
|
||||||
if (request is FiberRequest.ExpectingResponse<*>) {
|
if (request is FiberRequest.ExpectingResponse<*>) {
|
||||||
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
||||||
checkpointOnExpectingResponse(psm, request, serializeFiber(fiber))
|
checkpointOnExpectingResponse(psm, request)
|
||||||
}
|
}
|
||||||
// If a non-null payload to send was provided, send it now.
|
// If a non-null payload to send was provided, send it now.
|
||||||
request.payload?.let {
|
request.payload?.let {
|
||||||
@ -267,11 +266,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>,
|
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) {
|
||||||
request: FiberRequest.ExpectingResponse<*>,
|
|
||||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
|
|
||||||
executor.checkOnThread()
|
executor.checkOnThread()
|
||||||
val topic = "${request.topic}.${request.sessionIDForReceive}"
|
val topic = "${request.topic}.${request.sessionIDForReceive}"
|
||||||
|
val serialisedFiber = serializeFiber(psm)
|
||||||
updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null)
|
updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null)
|
||||||
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" }
|
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" }
|
||||||
iterateOnResponse(psm, request.responseType, serialisedFiber, topic) {
|
iterateOnResponse(psm, request.responseType, serialisedFiber, topic) {
|
||||||
|
@ -14,7 +14,7 @@ import com.r3corda.node.services.persistence.DataVendingService
|
|||||||
import com.r3corda.node.services.wallet.NodeWalletService
|
import com.r3corda.node.services.wallet.NodeWalletService
|
||||||
import java.time.Clock
|
import java.time.Clock
|
||||||
|
|
||||||
class MockServices(
|
open class MockServices(
|
||||||
customWallet: WalletService? = null,
|
customWallet: WalletService? = null,
|
||||||
val keyManagement: KeyManagementService? = null,
|
val keyManagement: KeyManagementService? = null,
|
||||||
val net: MessagingService? = null,
|
val net: MessagingService? = null,
|
||||||
|
@ -0,0 +1,74 @@
|
|||||||
|
package com.r3corda.node.services.statemachine
|
||||||
|
|
||||||
|
import co.paralleluniverse.fibers.Fiber
|
||||||
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
|
import com.r3corda.core.messaging.MessagingService
|
||||||
|
import com.r3corda.core.protocols.ProtocolLogic
|
||||||
|
import com.r3corda.node.services.MockServices
|
||||||
|
import com.r3corda.node.services.api.Checkpoint
|
||||||
|
import com.r3corda.node.services.api.CheckpointStorage
|
||||||
|
import com.r3corda.node.services.network.InMemoryMessagingNetwork
|
||||||
|
import com.r3corda.node.utilities.AffinityExecutor
|
||||||
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
|
import org.junit.After
|
||||||
|
import org.junit.Test
|
||||||
|
import java.util.*
|
||||||
|
|
||||||
|
class StateMachineManagerTests {
|
||||||
|
|
||||||
|
val checkpointStorage = RecordingCheckpointStorage()
|
||||||
|
val network = InMemoryMessagingNetwork().InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock"))
|
||||||
|
val smm = createManager()
|
||||||
|
|
||||||
|
@After
|
||||||
|
fun cleanUp() {
|
||||||
|
network.stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `newly added protocol is preserved on restart`() {
|
||||||
|
smm.add("mock", ProtocolWithoutCheckpoints())
|
||||||
|
// Ensure we're restoring from the original add checkpoint
|
||||||
|
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
|
||||||
|
val restoredProtocol = createManager().run {
|
||||||
|
start()
|
||||||
|
findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first
|
||||||
|
}
|
||||||
|
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun createManager() = StateMachineManager(object : MockServices() {
|
||||||
|
override val networkService: MessagingService get() = network
|
||||||
|
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
|
||||||
|
|
||||||
|
|
||||||
|
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
|
||||||
|
|
||||||
|
@Transient var protocolStarted = false
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
override fun call() {
|
||||||
|
protocolStarted = true
|
||||||
|
Fiber.park()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingCheckpointStorage : CheckpointStorage {
|
||||||
|
|
||||||
|
private val _checkpoints = ArrayList<Checkpoint>()
|
||||||
|
val allCheckpoints = ArrayList<Checkpoint>()
|
||||||
|
|
||||||
|
override fun addCheckpoint(checkpoint: Checkpoint) {
|
||||||
|
_checkpoints.add(checkpoint)
|
||||||
|
allCheckpoints.add(checkpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun removeCheckpoint(checkpoint: Checkpoint) {
|
||||||
|
_checkpoints.remove(checkpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val checkpoints: Iterable<Checkpoint> get() = _checkpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Reference in New Issue
Block a user