diff --git a/contracts/src/main/kotlin/com/r3corda/contracts/testing/TestUtils.kt b/contracts/src/main/kotlin/com/r3corda/contracts/testing/TestUtils.kt index 1ebf87d512..692d25ceab 100644 --- a/contracts/src/main/kotlin/com/r3corda/contracts/testing/TestUtils.kt +++ b/contracts/src/main/kotlin/com/r3corda/contracts/testing/TestUtils.kt @@ -1,10 +1,11 @@ package com.r3corda.contracts.testing import com.r3corda.contracts.* -import com.r3corda.contracts.cash.Cash import com.r3corda.contracts.cash.CASH_PROGRAM_ID +import com.r3corda.contracts.cash.Cash import com.r3corda.core.contracts.Amount import com.r3corda.core.contracts.Contract +import com.r3corda.core.contracts.DUMMY_PROGRAM_ID import com.r3corda.core.contracts.DummyContract import com.r3corda.core.crypto.NullPublicKey import com.r3corda.core.crypto.Party diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index 38cec65a3a..7c8881ef75 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -24,7 +24,7 @@ import org.slf4j.LoggerFactory class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberScheduler, private val loggerName: String) : Fiber("protocol", scheduler), ProtocolStateMachine { // These fields shouldn't be serialised, so they are marked @Transient. - @Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, fiber: ProtocolStateMachineImpl<*>) -> Unit)? = null + @Transient private var suspendAction: ((result: StateMachineManager.FiberRequest) -> Unit)? = null @Transient private var receivedPayload: Any? = null @Transient lateinit override var serviceHub: ServiceHubInternal @Transient internal lateinit var actionOnEnd: () -> Unit @@ -54,7 +54,7 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberS fun prepareForResumeWith(serviceHub: ServiceHubInternal, receivedPayload: Any?, - suspendAction: (StateMachineManager.FiberRequest, ProtocolStateMachineImpl<*>) -> Unit) { + suspendAction: (StateMachineManager.FiberRequest) -> Unit) { this.serviceHub = serviceHub this.receivedPayload = receivedPayload this.suspendAction = suspendAction @@ -108,7 +108,7 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, scheduler: FiberS private fun suspend(with: StateMachineManager.FiberRequest) { parkAndSerialize { fiber, serializer -> try { - suspendAction!!(with, this) + suspendAction!!(with) } catch (t: Throwable) { logger.warn("Captured exception which was swallowed by Quasar", t) // TODO to throw or not to throw, that is the question diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index 25a9022089..ca40b18565 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -177,7 +177,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } } - private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint?) { + private fun initFiber(psm: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) { stateMachines[psm] = checkpoint psm.actionOnEnd = { psm.logic.progressTracker?.currentStep = ProgressTracker.DONE @@ -199,9 +199,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService fun add(loggerName: String, logic: ProtocolLogic): ListenableFuture { try { val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) + val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null) + checkpointStorage.addCheckpoint(checkpoint) // Need to add before iterating in case of immediate completion - // TODO: create an initial checkpoint here - initFiber(fiber, null) + initFiber(fiber, checkpoint) executor.executeASAP { iterateStateMachine(fiber, null) { fiber.start() @@ -233,21 +234,19 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService receivedPayload: Any?, resumeAction: (Any?) -> Unit) { executor.checkOnThread() - psm.prepareForResumeWith(serviceHub, receivedPayload) { request, serialisedFiber -> + psm.prepareForResumeWith(serviceHub, receivedPayload) { request -> psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" } - onNextSuspend(psm, request, serialisedFiber) + onNextSuspend(psm, request) } psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" } resumeAction(receivedPayload) } - private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, - request: FiberRequest, - fiber: ProtocolStateMachineImpl<*>) { + private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) { // We have a request to do something: send, receive, or send-and-receive. if (request is FiberRequest.ExpectingResponse<*>) { // Prepare a listener on the network that runs in the background thread when we receive a message. - checkpointOnExpectingResponse(psm, request, serializeFiber(fiber)) + checkpointOnExpectingResponse(psm, request) } // If a non-null payload to send was provided, send it now. request.payload?.let { @@ -267,11 +266,10 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } } - private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, - request: FiberRequest.ExpectingResponse<*>, - serialisedFiber: SerializedBytes>) { + private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) { executor.checkOnThread() val topic = "${request.topic}.${request.sessionIDForReceive}" + val serialisedFiber = serializeFiber(psm) updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null) psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" } iterateOnResponse(psm, request.responseType, serialisedFiber, topic) { diff --git a/node/src/test/kotlin/com/r3corda/node/services/MockServices.kt b/node/src/test/kotlin/com/r3corda/node/services/MockServices.kt index 9fb51f1065..736f97a1c9 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/MockServices.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/MockServices.kt @@ -14,7 +14,7 @@ import com.r3corda.node.services.persistence.DataVendingService import com.r3corda.node.services.wallet.NodeWalletService import java.time.Clock -class MockServices( +open class MockServices( customWallet: WalletService? = null, val keyManagement: KeyManagementService? = null, val net: MessagingService? = null, diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt new file mode 100644 index 0000000000..eb6552d791 --- /dev/null +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -0,0 +1,74 @@ +package com.r3corda.node.services.statemachine + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.Suspendable +import com.r3corda.core.messaging.MessagingService +import com.r3corda.core.protocols.ProtocolLogic +import com.r3corda.node.services.MockServices +import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.api.CheckpointStorage +import com.r3corda.node.services.network.InMemoryMessagingNetwork +import com.r3corda.node.utilities.AffinityExecutor +import org.assertj.core.api.Assertions.assertThat +import org.junit.After +import org.junit.Test +import java.util.* + +class StateMachineManagerTests { + + val checkpointStorage = RecordingCheckpointStorage() + val network = InMemoryMessagingNetwork().InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock")) + val smm = createManager() + + @After + fun cleanUp() { + network.stop() + } + + @Test + fun `newly added protocol is preserved on restart`() { + smm.add("mock", ProtocolWithoutCheckpoints()) + // Ensure we're restoring from the original add checkpoint + assertThat(checkpointStorage.allCheckpoints).hasSize(1) + val restoredProtocol = createManager().run { + start() + findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first + } + assertThat(restoredProtocol.protocolStarted).isTrue() + } + + private fun createManager() = StateMachineManager(object : MockServices() { + override val networkService: MessagingService get() = network + }, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD) + + + private class ProtocolWithoutCheckpoints : ProtocolLogic() { + + @Transient var protocolStarted = false + + @Suspendable + override fun call() { + protocolStarted = true + Fiber.park() + } + } + + + class RecordingCheckpointStorage : CheckpointStorage { + + private val _checkpoints = ArrayList() + val allCheckpoints = ArrayList() + + override fun addCheckpoint(checkpoint: Checkpoint) { + _checkpoints.add(checkpoint) + allCheckpoints.add(checkpoint) + } + + override fun removeCheckpoint(checkpoint: Checkpoint) { + _checkpoints.remove(checkpoint) + } + + override val checkpoints: Iterable get() = _checkpoints + } + +} \ No newline at end of file