diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt index 54a972d26f..65ab9b43fa 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolStateMachine.kt @@ -7,17 +7,20 @@ import com.r3corda.core.node.ServiceHub import com.r3corda.core.utilities.UntrustworthyData import org.slf4j.Logger - /** * The interface of [ProtocolStateMachineImpl] exposing methods and properties required by ProtocolLogic for compilation. */ interface ProtocolStateMachine { @Suspendable - fun sendAndReceive(topic: String, destination: Party, sessionIDForSend: Long, sessionIDForReceive: Long, - payload: Any, recvType: Class): UntrustworthyData + fun sendAndReceive(topic: String, + destination: Party, + sessionIDForSend: Long, + sessionIDForReceive: Long, + payload: Any, + receiveType: Class): UntrustworthyData @Suspendable - fun receive(topic: String, sessionIDForReceive: Long, recvType: Class): UntrustworthyData + fun receive(topic: String, sessionIDForReceive: Long, receiveType: Class): UntrustworthyData @Suspendable fun send(topic: String, destination: Party, sessionID: Long, payload: Any) diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt index 9426e3cae5..e742d29f9d 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt @@ -1,7 +1,7 @@ package com.r3corda.node.services.api import com.r3corda.core.serialization.SerializedBytes -import com.r3corda.node.services.statemachine.FiberRequest +import com.r3corda.node.services.statemachine.ProtocolIORequest import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl /** @@ -32,5 +32,6 @@ interface CheckpointStorage { // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). data class Checkpoint( val serialisedFiber: SerializedBytes>, - val request: FiberRequest? + val request: ProtocolIORequest?, + val receivedPayload: Any? ) \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/FiberRequest.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/FiberRequest.kt deleted file mode 100644 index b3fd7911d8..0000000000 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/FiberRequest.kt +++ /dev/null @@ -1,86 +0,0 @@ -package com.r3corda.node.services.statemachine - -import com.r3corda.core.crypto.Party -import com.r3corda.core.messaging.TopicSession - -// TODO: Clean this up -sealed class FiberRequest(val topic: String, - val destination: Party?, - val sessionIDForSend: Long, - val sessionIDForReceive: Long, - val payload: Any?) { - // This is used to identify where we suspended, in case of message mismatch errors and other things where we - // don't have the original stack trace because it's in a suspended fiber. - @Transient - val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot() - - val receiveTopicSession: TopicSession - get() = TopicSession(topic, sessionIDForReceive) - - - override fun equals(other: Any?): Boolean - = if (other is FiberRequest) { - topic == other.topic - && destination == other.destination - && sessionIDForSend == other.sessionIDForSend - && sessionIDForReceive == other.sessionIDForReceive - && payload == other.payload - } else - false - - override fun hashCode(): Int { - var hash = 1L - - hash = (hash * 31) + topic.hashCode() - hash = (hash * 31) + if (destination == null) - 0 - else - destination.hashCode() - hash = (hash * 31) + sessionIDForReceive - hash = (hash * 31) + sessionIDForReceive - hash = (hash * 31) + if (payload == null) - 0 - else - payload.hashCode() - - return hash.toInt() - } - - /** - * A fiber which is expecting a message response. - */ - class ExpectingResponse( - topic: String, - destination: Party?, - sessionIDForSend: Long, - sessionIDForReceive: Long, - obj: Any?, - type: Class - ) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) { - private val responseTypeName: String = type.name - - override fun equals(other: Any?): Boolean { - return if (other is ExpectingResponse<*>) { - super.equals(other) && responseTypeName == other.responseTypeName - } else - false - } - - override fun toString(): String { - return "Expecting response via topic $receiveTopicSession of type $responseTypeName" - } - - // We have to do an unchecked cast, but unless the serialized form is damaged, this was - // correct when the request was instantiated - @Suppress("UNCHECKED_CAST") - val responseType: Class - get() = Class.forName(responseTypeName) as Class - } - - class NotExpectingResponse( - topic: String, - destination: Party, - sessionIDForSend: Long, - obj: Any? - ) : FiberRequest(topic, destination, sessionIDForSend, -1, obj) -} \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt new file mode 100644 index 0000000000..51ae58bfa7 --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolIORequest.kt @@ -0,0 +1,51 @@ +package com.r3corda.node.services.statemachine + +import com.r3corda.core.crypto.Party +import com.r3corda.core.messaging.TopicSession + +// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes +interface ProtocolIORequest { + // This is used to identify where we suspended, in case of message mismatch errors and other things where we + // don't have the original stack trace because it's in a suspended fiber. + val stackTraceInCaseOfProblems: StackSnapshot + val topic: String +} + +interface SendRequest : ProtocolIORequest { + val destination: Party + val payload: Any + val sendSessionID: Long +} + +interface ReceiveRequest : ProtocolIORequest { + val receiveType: Class + val receiveSessionID: Long + val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID) +} + +data class SendAndReceive(override val topic: String, + override val destination: Party, + override val payload: Any, + override val sendSessionID: Long, + override val receiveType: Class, + override val receiveSessionID: Long) : SendRequest, ReceiveRequest { + @Transient + override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() +} + +data class ReceiveOnly(override val topic: String, + override val receiveType: Class, + override val receiveSessionID: Long) : ReceiveRequest { + @Transient + override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() +} + +data class SendOnly(override val destination: Party, + override val topic: String, + override val payload: Any, + override val sendSessionID: Long) : SendRequest { + @Transient + override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() +} + +class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index 8a4400b620..d3a55dc5c8 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -29,7 +29,7 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, // These fields shouldn't be serialised, so they are marked @Transient. @Transient lateinit override var serviceHub: ServiceHubInternal - @Transient internal lateinit var suspendAction: (FiberRequest) -> Unit + @Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal var receivedPayload: Any? = null @@ -77,43 +77,40 @@ class ProtocolStateMachineImpl(val logic: ProtocolLogic, return result } - @Suspendable @Suppress("UNCHECKED_CAST") - private fun suspendAndExpectReceive(with: FiberRequest): UntrustworthyData { - suspend(with) + @Suspendable + private fun suspendAndExpectReceive(receiveRequest: ReceiveRequest): UntrustworthyData { + suspend(receiveRequest) check(receivedPayload != null) { "Expected to receive something" } - val untrustworthy = UntrustworthyData(receivedPayload as T) + val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload)) receivedPayload = null return untrustworthy } - @Suspendable @Suppress("UNCHECKED_CAST") + @Suspendable override fun sendAndReceive(topic: String, destination: Party, sessionIDForSend: Long, sessionIDForReceive: Long, payload: Any, - recvType: Class): UntrustworthyData { - val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType) - return suspendAndExpectReceive(result) + receiveType: Class): UntrustworthyData { + return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, receiveType, sessionIDForReceive)) } @Suspendable - override fun receive(topic: String, sessionIDForReceive: Long, recvType: Class): UntrustworthyData { - val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType) - return suspendAndExpectReceive(result) + override fun receive(topic: String, sessionIDForReceive: Long, receiveType: Class): UntrustworthyData { + return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive)) } @Suspendable override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) { - val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload) - suspend(result) + suspend(SendOnly(destination, topic, payload, sessionID)) } @Suspendable - private fun suspend(with: FiberRequest) { + private fun suspend(protocolIORequest: ProtocolIORequest) { parkAndSerialize { fiber, serializer -> try { - suspendAction(with) + suspendAction(protocolIORequest) } catch (t: Throwable) { // Do not throw exception again - Quasar completely bins it. logger.warn("Captured exception which was swallowed by Quasar", t) diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index a44881da14..ffb3af621e 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -125,38 +125,25 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val fiber = deserializeFiber(checkpoint.serialisedFiber) initFiber(fiber, { checkpoint }) - when (checkpoint.request) { - is FiberRequest.ExpectingResponse<*> -> { - val topic = checkpoint.request.receiveTopicSession - val awaitingPayloadType = checkpoint.request.responseType - fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic") - iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, checkpoint.request) { - try { - Fiber.unparkDeserialized(fiber, scheduler) - } catch (e: Throwable) { - logError(e, it, topic, fiber) - } - } - } - is FiberRequest.NotExpectingResponse -> restoreNotExpectingResponse(fiber, checkpoint.request) - null -> restoreNotExpectingResponse(fiber) - else -> throw IllegalStateException("Unknown fiber request type " + checkpoint.request) - } - } - - /** - * Restore a Fiber which was not expecting a response (either because it hasn't asked for one, such as a new - * Fiber, or because the detail it needed has arrived). - */ - private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) { - val payload = request?.payload - fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${payload.toString().abbreviate(50)}") - executor.executeASAP { - iterateStateMachine(fiber, payload) { + if (checkpoint.request is ReceiveRequest<*>) { + val topicSession = checkpoint.request.receiveTopicSession + fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession") + iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) { try { Fiber.unparkDeserialized(fiber, scheduler) } catch (e: Throwable) { - logError(e, it, null, fiber) + logError(e, it, topicSession, fiber) + } + } + } else { + fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}") + executor.executeASAP { + iterateStateMachine(fiber, checkpoint.receivedPayload) { + try { + Fiber.unparkDeserialized(fiber, scheduler) + } catch (e: Throwable) { + logError(e, it, null, fiber) + } } } } @@ -220,7 +207,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) // Need to add before iterating in case of immediate completion initFiber(fiber) { - val checkpoint = Checkpoint(serializeFiber(fiber), null) + val checkpoint = Checkpoint(serializeFiber(fiber), null, null) checkpointStorage.addCheckpoint(checkpoint) checkpoint } @@ -250,8 +237,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes>, - request: FiberRequest) { - val newCheckpoint = Checkpoint(serialisedFiber, request) + request: ProtocolIORequest?, + receivedPayload: Any?) { + val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload) val previousCheckpoint = stateMachines.put(psm, newCheckpoint) if (previousCheckpoint != null) { checkpointStorage.removeCheckpoint(previousCheckpoint) @@ -269,62 +257,62 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService resumeAction(receivedPayload) } - private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) { + private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) { // We have a request to do something: send, receive, or send-and-receive. - when (request) { - is FiberRequest.ExpectingResponse<*> -> { - // Prepare a listener on the network that runs in the background thread when we receive a message. - checkpointOnExpectingResponse(psm, request) + if (request is ReceiveRequest<*>) { + // Prepare a listener on the network that runs in the background thread when we receive a message. + prepareToReceiveForRequest(psm, request) + } + if (request is SendRequest) { + performSendRequest(psm, request) + } + } + + private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, request: ReceiveRequest<*>) { + executor.checkOnThread() + val queueID = request.receiveTopicSession + val serialisedFiber = serializeFiber(psm) + updateCheckpoint(psm, serialisedFiber, request, null) + psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" } + iterateOnResponse(psm, serialisedFiber, request) { + try { + Fiber.unpark(psm, QUASAR_UNBLOCKER) + } catch(e: Throwable) { + logError(e, it, queueID, psm) } } - // If a non-null payload to send was provided, send it now. - val queueID = TopicSession(request.topic, request.sessionIDForSend) - request.payload?.let { - psm.logger.trace { "Sending message of type ${it.javaClass.name} using queue $queueID to ${request.destination} (${it.toString().abbreviate(50)})" } - val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination!!.name) - if (node == null) { - throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${it.javaClass.name} on $queueID (${it.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems) - } - serviceHub.networkService.send(queueID, it, node.address) - } - if (request is FiberRequest.NotExpectingResponse) { + } + + private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) { + val topicSession = TopicSession(request.topic, request.sendSessionID) + val payload = request.payload + psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" } + val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?: + throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems) + serviceHub.networkService.send(topicSession, payload, node.address) + + if (request is SendOnly) { // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. iterateStateMachine(psm, null) { try { Fiber.unpark(psm, QUASAR_UNBLOCKER) } catch(e: Throwable) { - logError(e, request.payload, queueID, psm) + logError(e, request.payload, topicSession, psm) } } } } - private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) { - executor.checkOnThread() - val queueID = request.receiveTopicSession - val serialisedFiber = serializeFiber(psm) - updateCheckpoint(psm, serialisedFiber, request) - psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on queue $queueID" } - iterateOnResponse(psm, request.responseType, serialisedFiber, request) { - try { - Fiber.unpark(psm, QUASAR_UNBLOCKER) - } catch(e: Throwable) { - logError(e, it, queueID, psm) - } - } - } - /** * Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is * received. */ private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, - responseType: Class<*>, serialisedFiber: SerializedBytes>, - request: FiberRequest.ExpectingResponse<*>, + request: ReceiveRequest<*>, resumeAction: (Any?) -> Unit) { - val topic = request.receiveTopicSession - serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg -> + val topicSession = request.receiveTopicSession + serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg -> // Assertion to ensure we don't execute on the wrong thread. executor.checkOnThread() // TODO: This is insecure: we should not deserialise whatever we find and *then* check. @@ -334,13 +322,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService // at the last moment when we do the downcast. However this would make protocol code harder to read and // make it more difficult to migrate to a more explicit serialisation scheme later. val payload = netMsg.data.deserialize() - check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" } + check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" } // Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload - updateCheckpoint(psm, serialisedFiber, request) - psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic ${request.topic}.${request.sessionIDForReceive} (${payload.toString().abbreviate(50)})" } + updateCheckpoint(psm, serialisedFiber, null, payload) + psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" } iterateStateMachine(psm, payload, resumeAction) } } } - -class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt index cee289e5ee..9c76e8d7c4 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt @@ -3,10 +3,8 @@ package com.r3corda.node.services.persistence import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Jimfs import com.google.common.primitives.Ints -import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.SerializedBytes import com.r3corda.node.services.api.Checkpoint -import com.r3corda.node.services.statemachine.FiberRequest import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.After @@ -94,8 +92,6 @@ class PerFileCheckpointStorageTests { } private var checkpointCount = 1 - private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null, - kotlin.String::class.java) - private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request) + private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null) } \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index 8151c016c8..0798fef0f0 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -2,62 +2,76 @@ package com.r3corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable +import com.r3corda.core.crypto.Party +import com.r3corda.core.node.NodeInfo import com.r3corda.core.protocols.ProtocolLogic -import com.r3corda.node.services.MockServiceHubInternal -import com.r3corda.node.services.api.Checkpoint -import com.r3corda.node.services.api.CheckpointStorage -import com.r3corda.node.services.api.MessagingServiceInternal -import com.r3corda.testing.node.InMemoryMessagingNetwork -import com.r3corda.node.utilities.AffinityExecutor +import com.r3corda.core.random63BitValue +import com.r3corda.testing.node.MockNetwork +import com.r3corda.testing.node.MockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat import org.junit.After +import org.junit.Before import org.junit.Test -import java.util.* class StateMachineManagerTests { - val checkpointStorage = RecordingCheckpointStorage() - val network = InMemoryMessagingNetwork(false).InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock")) - val smm = createManager() + val net = MockNetwork() + lateinit var node1: MockNode + lateinit var node2: MockNode + + @Before + fun start() { + val nodes = net.createTwoNodes() + node1 = nodes.first + node2 = nodes.second + net.runNetwork() + } @After fun cleanUp() { - network.stop() + net.stopNodes() } @Test fun `newly added protocol is preserved on restart`() { - smm.add("test", ProtocolWithoutCheckpoints()) - // Ensure we're restoring from the original add checkpoint - assertThat(checkpointStorage.allCheckpoints).hasSize(1) - val restoredProtocol = createManager().run { - start() - findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first - } + node1.smm.add("test", ProtocolWithoutCheckpoints()) + val restoredProtocol = node1.restartAndGetRestoredProtocol() assertThat(restoredProtocol.protocolStarted).isTrue() } @Test fun `protocol can lazily use the serviceHub in its constructor`() { val protocol = ProtocolWithLazyServiceHub() - smm.add("test", protocol) + node1.smm.add("test", protocol) assertThat(protocol.lazyTime).isNotNull() } - private fun createManager() = StateMachineManager(object : MockServiceHubInternal() { - override val networkService: MessagingServiceInternal get() = network - }, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD) + @Test + fun `protocol suspended just after receiving payload`() { + val topic = "send-and-receive" + val sessionID = random63BitValue() + val payload = random63BitValue() + node1.smm.add("test", SendProtocol(topic, node2.info.identity, sessionID, payload)) + node2.smm.add("test", ReceiveProtocol(topic, sessionID)) + net.runNetwork() + node2.stop() + val restoredProtocol = node2.restartAndGetRestoredProtocol(node1.info) + assertThat(restoredProtocol.receivedPayload).isEqualTo(payload) + } + + private inline fun MockNode.restartAndGetRestoredProtocol(networkMapAddress: NodeInfo? = null): P { + val node = mockNet.createNode(networkMapAddress, id) + return node.smm.findStateMachines(P::class.java).single().first + } - - private class ProtocolWithoutCheckpoints : ProtocolLogic() { + private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() { @Transient var protocolStarted = false @Suspendable - override fun call() { + override fun doCall() { protocolStarted = true - Fiber.park() } override val topic: String get() = throw UnsupportedOperationException() @@ -75,21 +89,37 @@ class StateMachineManagerTests { } - class RecordingCheckpointStorage : CheckpointStorage { + private class SendProtocol(override val topic: String, val destination: Party, val sessionID: Long, val payload: Any) : ProtocolLogic() { + @Suspendable + override fun call() = send(destination, sessionID, payload) + } - private val _checkpoints = ArrayList() - val allCheckpoints = ArrayList() - override fun addCheckpoint(checkpoint: Checkpoint) { - _checkpoints.add(checkpoint) - allCheckpoints.add(checkpoint) + private class ReceiveProtocol(override val topic: String, val sessionID: Long) : NonTerminatingProtocol() { + + @Transient var receivedPayload: Any? = null + + @Suspendable + override fun doCall() { + receivedPayload = receive(sessionID).validate { it } + } + } + + + /** + * A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after + * restart for testing checkpoint restoration. Store any results as @Transient fields. + */ + private abstract class NonTerminatingProtocol : ProtocolLogic() { + + @Suspendable + override fun call() { + doCall() + Fiber.park() } - override fun removeCheckpoint(checkpoint: Checkpoint) { - _checkpoints.remove(checkpoint) - } - - override val checkpoints: Iterable get() = _checkpoints + @Suspendable + abstract fun doCall() } }