Refactored FiberRequest into cleaner ProtocolIORequest and fixed checkpoint regression

This commit is contained in:
Shams Asari 2016-09-01 19:47:55 +01:00
parent a2d7490902
commit 97e1a59770
8 changed files with 201 additions and 223 deletions

View File

@ -7,17 +7,20 @@ import com.r3corda.core.node.ServiceHub
import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.UntrustworthyData
import org.slf4j.Logger import org.slf4j.Logger
/** /**
* The interface of [ProtocolStateMachineImpl] exposing methods and properties required by ProtocolLogic for compilation. * The interface of [ProtocolStateMachineImpl] exposing methods and properties required by ProtocolLogic for compilation.
*/ */
interface ProtocolStateMachine<R> { interface ProtocolStateMachine<R> {
@Suspendable @Suspendable
fun <T : Any> sendAndReceive(topic: String, destination: Party, sessionIDForSend: Long, sessionIDForReceive: Long, fun <T : Any> sendAndReceive(topic: String,
payload: Any, recvType: Class<T>): UntrustworthyData<T> destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
payload: Any,
receiveType: Class<T>): UntrustworthyData<T>
@Suspendable @Suspendable
fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T>
@Suspendable @Suspendable
fun send(topic: String, destination: Party, sessionID: Long, payload: Any) fun send(topic: String, destination: Party, sessionID: Long, payload: Any)

View File

@ -1,7 +1,7 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.statemachine.FiberRequest import com.r3corda.node.services.statemachine.ProtocolIORequest
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
/** /**
@ -32,5 +32,6 @@ interface CheckpointStorage {
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
data class Checkpoint( data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: FiberRequest? val request: ProtocolIORequest?,
val receivedPayload: Any?
) )

View File

@ -1,86 +0,0 @@
package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession
// TODO: Clean this up
sealed class FiberRequest(val topic: String,
val destination: Party?,
val sessionIDForSend: Long,
val sessionIDForReceive: Long,
val payload: Any?) {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
@Transient
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
val receiveTopicSession: TopicSession
get() = TopicSession(topic, sessionIDForReceive)
override fun equals(other: Any?): Boolean
= if (other is FiberRequest) {
topic == other.topic
&& destination == other.destination
&& sessionIDForSend == other.sessionIDForSend
&& sessionIDForReceive == other.sessionIDForReceive
&& payload == other.payload
} else
false
override fun hashCode(): Int {
var hash = 1L
hash = (hash * 31) + topic.hashCode()
hash = (hash * 31) + if (destination == null)
0
else
destination.hashCode()
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + if (payload == null)
0
else
payload.hashCode()
return hash.toInt()
}
/**
* A fiber which is expecting a message response.
*/
class ExpectingResponse<R : Any>(
topic: String,
destination: Party?,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
type: Class<R>
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
private val responseTypeName: String = type.name
override fun equals(other: Any?): Boolean {
return if (other is ExpectingResponse<*>) {
super.equals(other) && responseTypeName == other.responseTypeName
} else
false
}
override fun toString(): String {
return "Expecting response via topic $receiveTopicSession of type $responseTypeName"
}
// We have to do an unchecked cast, but unless the serialized form is damaged, this was
// correct when the request was instantiated
@Suppress("UNCHECKED_CAST")
val responseType: Class<R>
get() = Class.forName(responseTypeName) as Class<R>
}
class NotExpectingResponse(
topic: String,
destination: Party,
sessionIDForSend: Long,
obj: Any?
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}

View File

@ -0,0 +1,51 @@
package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
interface ProtocolIORequest {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems: StackSnapshot
val topic: String
}
interface SendRequest : ProtocolIORequest {
val destination: Party
val payload: Any
val sendSessionID: Long
}
interface ReceiveRequest<T> : ProtocolIORequest {
val receiveType: Class<T>
val receiveSessionID: Long
val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID)
}
data class SendAndReceive<T>(override val topic: String,
override val destination: Party,
override val payload: Any,
override val sendSessionID: Long,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : SendRequest, ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class ReceiveOnly<T>(override val topic: String,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class SendOnly(override val destination: Party,
override val topic: String,
override val payload: Any,
override val sendSessionID: Long) : SendRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -29,7 +29,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
// These fields shouldn't be serialised, so they are marked @Transient. // These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal @Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (FiberRequest) -> Unit @Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null @Transient internal var receivedPayload: Any? = null
@ -77,43 +77,40 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
return result return result
} }
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable
private fun <T : Any> suspendAndExpectReceive(with: FiberRequest): UntrustworthyData<T> { private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> {
suspend(with) suspend(receiveRequest)
check(receivedPayload != null) { "Expected to receive something" } check(receivedPayload != null) { "Expected to receive something" }
val untrustworthy = UntrustworthyData(receivedPayload as T) val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload))
receivedPayload = null receivedPayload = null
return untrustworthy return untrustworthy
} }
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable
override fun <T : Any> sendAndReceive(topic: String, override fun <T : Any> sendAndReceive(topic: String,
destination: Party, destination: Party,
sessionIDForSend: Long, sessionIDForSend: Long,
sessionIDForReceive: Long, sessionIDForReceive: Long,
payload: Any, payload: Any,
recvType: Class<T>): UntrustworthyData<T> { receiveType: Class<T>): UntrustworthyData<T> {
val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType) return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, receiveType, sessionIDForReceive))
return suspendAndExpectReceive(result)
} }
@Suspendable @Suspendable
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> { override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> {
val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType) return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive))
return suspendAndExpectReceive(result)
} }
@Suspendable @Suspendable
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) { override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload) suspend(SendOnly(destination, topic, payload, sessionID))
suspend(result)
} }
@Suspendable @Suspendable
private fun suspend(with: FiberRequest) { private fun suspend(protocolIORequest: ProtocolIORequest) {
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
try { try {
suspendAction(with) suspendAction(protocolIORequest)
} catch (t: Throwable) { } catch (t: Throwable) {
// Do not throw exception again - Quasar completely bins it. // Do not throw exception again - Quasar completely bins it.
logger.warn("Captured exception which was swallowed by Quasar", t) logger.warn("Captured exception which was swallowed by Quasar", t)

View File

@ -125,38 +125,25 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val fiber = deserializeFiber(checkpoint.serialisedFiber) val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint }) initFiber(fiber, { checkpoint })
when (checkpoint.request) { if (checkpoint.request is ReceiveRequest<*>) {
is FiberRequest.ExpectingResponse<*> -> { val topicSession = checkpoint.request.receiveTopicSession
val topic = checkpoint.request.receiveTopicSession fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
val awaitingPayloadType = checkpoint.request.responseType iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) {
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic")
iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, checkpoint.request) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, topic, fiber)
}
}
}
is FiberRequest.NotExpectingResponse -> restoreNotExpectingResponse(fiber, checkpoint.request)
null -> restoreNotExpectingResponse(fiber)
else -> throw IllegalStateException("Unknown fiber request type " + checkpoint.request)
}
}
/**
* Restore a Fiber which was not expecting a response (either because it hasn't asked for one, such as a new
* Fiber, or because the detail it needed has arrived).
*/
private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) {
val payload = request?.payload
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${payload.toString().abbreviate(50)}")
executor.executeASAP {
iterateStateMachine(fiber, payload) {
try { try {
Fiber.unparkDeserialized(fiber, scheduler) Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) { } catch (e: Throwable) {
logError(e, it, null, fiber) logError(e, it, topicSession, fiber)
}
}
} else {
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
executor.executeASAP {
iterateStateMachine(fiber, checkpoint.receivedPayload) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, null, fiber)
}
} }
} }
} }
@ -220,7 +207,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion // Need to add before iterating in case of immediate completion
initFiber(fiber) { initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null) val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.addCheckpoint(checkpoint)
checkpoint checkpoint
} }
@ -250,8 +237,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>, private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: FiberRequest) { request: ProtocolIORequest?,
val newCheckpoint = Checkpoint(serialisedFiber, request) receivedPayload: Any?) {
val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload)
val previousCheckpoint = stateMachines.put(psm, newCheckpoint) val previousCheckpoint = stateMachines.put(psm, newCheckpoint)
if (previousCheckpoint != null) { if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint) checkpointStorage.removeCheckpoint(previousCheckpoint)
@ -269,62 +257,62 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
resumeAction(receivedPayload) resumeAction(receivedPayload)
} }
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) { private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) {
// We have a request to do something: send, receive, or send-and-receive. // We have a request to do something: send, receive, or send-and-receive.
when (request) { if (request is ReceiveRequest<*>) {
is FiberRequest.ExpectingResponse<*> -> { // Prepare a listener on the network that runs in the background thread when we receive a message.
// Prepare a listener on the network that runs in the background thread when we receive a message. prepareToReceiveForRequest(psm, request)
checkpointOnExpectingResponse(psm, request) }
if (request is SendRequest) {
performSendRequest(psm, request)
}
}
private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, request: ReceiveRequest<*>) {
executor.checkOnThread()
val queueID = request.receiveTopicSession
val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, request, null)
psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" }
iterateOnResponse(psm, serialisedFiber, request) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, it, queueID, psm)
} }
} }
// If a non-null payload to send was provided, send it now. }
val queueID = TopicSession(request.topic, request.sessionIDForSend)
request.payload?.let { private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) {
psm.logger.trace { "Sending message of type ${it.javaClass.name} using queue $queueID to ${request.destination} (${it.toString().abbreviate(50)})" } val topicSession = TopicSession(request.topic, request.sendSessionID)
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination!!.name) val payload = request.payload
if (node == null) { psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" }
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${it.javaClass.name} on $queueID (${it.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems) val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?:
} throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems)
serviceHub.networkService.send(queueID, it, node.address) serviceHub.networkService.send(topicSession, payload, node.address)
}
if (request is FiberRequest.NotExpectingResponse) { if (request is SendOnly) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
iterateStateMachine(psm, null) { iterateStateMachine(psm, null) {
try { try {
Fiber.unpark(psm, QUASAR_UNBLOCKER) Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) { } catch(e: Throwable) {
logError(e, request.payload, queueID, psm) logError(e, request.payload, topicSession, psm)
} }
} }
} }
} }
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) {
executor.checkOnThread()
val queueID = request.receiveTopicSession
val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, request)
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on queue $queueID" }
iterateOnResponse(psm, request.responseType, serialisedFiber, request) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, it, queueID, psm)
}
}
}
/** /**
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is * Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
* received. * received.
*/ */
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
responseType: Class<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: FiberRequest.ExpectingResponse<*>, request: ReceiveRequest<*>,
resumeAction: (Any?) -> Unit) { resumeAction: (Any?) -> Unit) {
val topic = request.receiveTopicSession val topicSession = request.receiveTopicSession
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg -> serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg ->
// Assertion to ensure we don't execute on the wrong thread. // Assertion to ensure we don't execute on the wrong thread.
executor.checkOnThread() executor.checkOnThread()
// TODO: This is insecure: we should not deserialise whatever we find and *then* check. // TODO: This is insecure: we should not deserialise whatever we find and *then* check.
@ -334,13 +322,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
// at the last moment when we do the downcast. However this would make protocol code harder to read and // at the last moment when we do the downcast. However this would make protocol code harder to read and
// make it more difficult to migrate to a more explicit serialisation scheme later. // make it more difficult to migrate to a more explicit serialisation scheme later.
val payload = netMsg.data.deserialize<Any>() val payload = netMsg.data.deserialize<Any>()
check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" } check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" }
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload // Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
updateCheckpoint(psm, serialisedFiber, request) updateCheckpoint(psm, serialisedFiber, null, payload)
psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic ${request.topic}.${request.sessionIDForReceive} (${payload.toString().abbreviate(50)})" } psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" }
iterateStateMachine(psm, payload, resumeAction) iterateStateMachine(psm, payload, resumeAction)
} }
} }
} }
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -3,10 +3,8 @@ package com.r3corda.node.services.persistence
import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints import com.google.common.primitives.Ints
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.api.Checkpoint import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.statemachine.FiberRequest
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After import org.junit.After
@ -94,8 +92,6 @@ class PerFileCheckpointStorageTests {
} }
private var checkpointCount = 1 private var checkpointCount = 1
private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null, private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null)
kotlin.String::class.java)
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)
} }

View File

@ -2,62 +2,76 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.node.services.MockServiceHubInternal import com.r3corda.core.random63BitValue
import com.r3corda.node.services.api.Checkpoint import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.api.CheckpointStorage import com.r3corda.testing.node.MockNetwork.MockNode
import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.utilities.AffinityExecutor
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Before
import org.junit.Test import org.junit.Test
import java.util.*
class StateMachineManagerTests { class StateMachineManagerTests {
val checkpointStorage = RecordingCheckpointStorage() val net = MockNetwork()
val network = InMemoryMessagingNetwork(false).InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock")) lateinit var node1: MockNode
val smm = createManager() lateinit var node2: MockNode
@Before
fun start() {
val nodes = net.createTwoNodes()
node1 = nodes.first
node2 = nodes.second
net.runNetwork()
}
@After @After
fun cleanUp() { fun cleanUp() {
network.stop() net.stopNodes()
} }
@Test @Test
fun `newly added protocol is preserved on restart`() { fun `newly added protocol is preserved on restart`() {
smm.add("test", ProtocolWithoutCheckpoints()) node1.smm.add("test", ProtocolWithoutCheckpoints())
// Ensure we're restoring from the original add checkpoint val restoredProtocol = node1.restartAndGetRestoredProtocol<ProtocolWithoutCheckpoints>()
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
val restoredProtocol = createManager().run {
start()
findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first
}
assertThat(restoredProtocol.protocolStarted).isTrue() assertThat(restoredProtocol.protocolStarted).isTrue()
} }
@Test @Test
fun `protocol can lazily use the serviceHub in its constructor`() { fun `protocol can lazily use the serviceHub in its constructor`() {
val protocol = ProtocolWithLazyServiceHub() val protocol = ProtocolWithLazyServiceHub()
smm.add("test", protocol) node1.smm.add("test", protocol)
assertThat(protocol.lazyTime).isNotNull() assertThat(protocol.lazyTime).isNotNull()
} }
private fun createManager() = StateMachineManager(object : MockServiceHubInternal() { @Test
override val networkService: MessagingServiceInternal get() = network fun `protocol suspended just after receiving payload`() {
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD) val topic = "send-and-receive"
val sessionID = random63BitValue()
val payload = random63BitValue()
node1.smm.add("test", SendProtocol(topic, node2.info.identity, sessionID, payload))
node2.smm.add("test", ReceiveProtocol(topic, sessionID))
net.runNetwork()
node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
}
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: NodeInfo? = null): P {
val node = mockNet.createNode(networkMapAddress, id)
return node.smm.findStateMachines(P::class.java).single().first
}
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
@Transient var protocolStarted = false @Transient var protocolStarted = false
@Suspendable @Suspendable
override fun call() { override fun doCall() {
protocolStarted = true protocolStarted = true
Fiber.park()
} }
override val topic: String get() = throw UnsupportedOperationException() override val topic: String get() = throw UnsupportedOperationException()
@ -75,21 +89,37 @@ class StateMachineManagerTests {
} }
class RecordingCheckpointStorage : CheckpointStorage { private class SendProtocol(override val topic: String, val destination: Party, val sessionID: Long, val payload: Any) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(destination, sessionID, payload)
}
private val _checkpoints = ArrayList<Checkpoint>()
val allCheckpoints = ArrayList<Checkpoint>()
override fun addCheckpoint(checkpoint: Checkpoint) { private class ReceiveProtocol(override val topic: String, val sessionID: Long) : NonTerminatingProtocol() {
_checkpoints.add(checkpoint)
allCheckpoints.add(checkpoint) @Transient var receivedPayload: Any? = null
@Suspendable
override fun doCall() {
receivedPayload = receive<Any>(sessionID).validate { it }
}
}
/**
* A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after
* restart for testing checkpoint restoration. Store any results as @Transient fields.
*/
private abstract class NonTerminatingProtocol : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
doCall()
Fiber.park()
} }
override fun removeCheckpoint(checkpoint: Checkpoint) { @Suspendable
_checkpoints.remove(checkpoint) abstract fun doCall()
}
override val checkpoints: Iterable<Checkpoint> get() = _checkpoints
} }
} }