From 31ee8ab60be2877615b336488c7b6771f30ecb5e Mon Sep 17 00:00:00 2001 From: Ross Nicoll Date: Fri, 22 Jul 2016 16:09:52 +0100 Subject: [PATCH] Rework checkpoint storage to include the FiberRequest Rework checkpoint storage to include the FiberRequest, so that different requests can be supported. --- .../node/services/api/CheckpointStorage.kt | 5 +- .../statemachine/StateMachineManager.kt | 140 +++++++++++++----- .../PerFileCheckpointStorageTests.kt | 6 +- 3 files changed, 111 insertions(+), 40 deletions(-) diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt index f80d0dcdfe..4128faeb70 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt @@ -2,6 +2,7 @@ package com.r3corda.node.services.api import com.r3corda.core.serialization.SerializedBytes import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl +import com.r3corda.node.services.statemachine.StateMachineManager /** * Thread-safe storage of fiber checkpoints. @@ -31,7 +32,5 @@ interface CheckpointStorage { // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). data class Checkpoint( val serialisedFiber: SerializedBytes>, - val awaitingTopic: String?, - val awaitingPayloadType: String?, - val receivedPayload: Any? + val request: StateMachineManager.FiberRequest? ) \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index 877a09af26..7be16a5e43 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -119,28 +119,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val fiber = deserializeFiber(checkpoint.serialisedFiber) initFiber(fiber, { checkpoint }) - val topic = checkpoint.awaitingTopic - if (topic != null) { - val awaitingPayloadType = Class.forName(checkpoint.awaitingPayloadType) - fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic") - iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, topic) { - try { - Fiber.unparkDeserialized(fiber, scheduler) - } catch (e: Throwable) { - logError(e, it, topic, fiber) - } - } - } else { - fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}") - executor.executeASAP { - iterateStateMachine(fiber, checkpoint.receivedPayload) { + when (checkpoint.request) { + is FiberRequest.ExpectingResponse<*> -> { + val topic = checkpoint.request.receiveTopic + val awaitingPayloadType = checkpoint.request.responseType + fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic") + iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, checkpoint.request) { try { Fiber.unparkDeserialized(fiber, scheduler) } catch (e: Throwable) { - logError(e, it, null, fiber) + logError(e, it, topic, fiber) } } } + is FiberRequest.NotExpectingResponse -> restoreNotExpectingResponse(fiber, checkpoint.request) + null -> restoreNotExpectingResponse(fiber) + else -> throw IllegalStateException("Unknown fiber request type " + checkpoint.request) + } + } + + /** + * Restore a Fiber which was not expecting a response (either because it hasn't asked for one, such as a new + * Fiber, or because the detail it needed has arrived). + */ + private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) { + val payload = request?.payload + fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${payload.toString().abbreviate(50)}") + executor.executeASAP { + iterateStateMachine(fiber, payload) { + try { + Fiber.unparkDeserialized(fiber, scheduler) + } catch (e: Throwable) { + logError(e, it, null, fiber) + } + } } } @@ -203,7 +215,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) // Need to add before iterating in case of immediate completion initFiber(fiber) { - val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null) + val checkpoint = Checkpoint(serializeFiber(fiber), null) checkpointStorage.addCheckpoint(checkpoint) checkpoint } @@ -222,10 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes>, - awaitingTopic: String?, - awaitingPayloadType: Class<*>?, - receivedPayload: Any?) { - val newCheckpoint = Checkpoint(serialisedFiber, awaitingTopic, awaitingPayloadType?.name, receivedPayload) + request: FiberRequest) { + val newCheckpoint = Checkpoint(serialisedFiber, request) val previousCheckpoint = stateMachines.put(psm, newCheckpoint) if (previousCheckpoint != null) { checkpointStorage.removeCheckpoint(previousCheckpoint) @@ -245,9 +255,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) { // We have a request to do something: send, receive, or send-and-receive. - if (request is FiberRequest.ExpectingResponse<*>) { - // Prepare a listener on the network that runs in the background thread when we receive a message. - checkpointOnExpectingResponse(psm, request) + when (request) { + is FiberRequest.ExpectingResponse<*> -> { + // Prepare a listener on the network that runs in the background thread when we receive a message. + checkpointOnExpectingResponse(psm, request) + } } // If a non-null payload to send was provided, send it now. request.payload?.let { @@ -273,9 +285,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService executor.checkOnThread() val topic = "${request.topic}.${request.sessionIDForReceive}" val serialisedFiber = serializeFiber(psm) - updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null) + updateCheckpoint(psm, serialisedFiber, request) psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" } - iterateOnResponse(psm, request.responseType, serialisedFiber, topic) { + iterateOnResponse(psm, request.responseType, serialisedFiber, request) { try { Fiber.unpark(psm, QUASAR_UNBLOCKER) } catch(e: Throwable) { @@ -284,11 +296,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } } + /** + * Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is + * received. + */ private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, responseType: Class<*>, serialisedFiber: SerializedBytes>, - topic: String, + request: StateMachineManager.FiberRequest.ExpectingResponse<*>, resumeAction: (Any?) -> Unit) { + val topic = request.receiveTopic serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg -> // Assertion to ensure we don't execute on the wrong thread. executor.checkOnThread() @@ -301,30 +318,81 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val payload = netMsg.data.deserialize() check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" } // Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload - updateCheckpoint(psm, serialisedFiber, null, null, payload) + updateCheckpoint(psm, serialisedFiber, request) psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic $topic (${payload.toString().abbreviate(50)})" } iterateStateMachine(psm, payload, resumeAction) } } // TODO: Clean this up - open class FiberRequest(val topic: String, - val destination: Party?, - val sessionIDForSend: Long, - val sessionIDForReceive: Long, - val payload: Any?) { + sealed class FiberRequest(val topic: String, + val destination: Party?, + val sessionIDForSend: Long, + val sessionIDForReceive: Long, + val payload: Any?) { // This is used to identify where we suspended, in case of message mismatch errors and other things where we // don't have the original stack trace because it's in a suspended fiber. - val stackTraceInCaseOfProblems = StackSnapshot() + @Transient + val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot() + val receiveTopic: String + get() = topic + "." + sessionIDForReceive + + + override fun equals(other: Any?): Boolean + = if (other is FiberRequest) { + topic == other.topic + && destination == other.destination + && sessionIDForSend == other.sessionIDForSend + && sessionIDForReceive == other.sessionIDForReceive + && payload == other.payload + } else + false + + override fun hashCode(): Int { + var hash = 1L + + hash = (hash * 31) + topic.hashCode() + hash = (hash * 31) + if (destination == null) + 0 + else + destination.hashCode() + hash = (hash * 31) + sessionIDForReceive + hash = (hash * 31) + sessionIDForReceive + hash = (hash * 31) + if (payload == null) + 0 + else + payload.hashCode() + + return hash.toInt() + } + + /** + * A fiber which is expecting a message response. + */ class ExpectingResponse( topic: String, destination: Party?, sessionIDForSend: Long, sessionIDForReceive: Long, obj: Any?, - val responseType: Class - ) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) + type: Class + ) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) { + private val responseTypeName: String = type.name + + override fun equals(other: Any?): Boolean + = if (other is FiberRequest.ExpectingResponse<*>) { + super.equals(other) + && responseTypeName == other.responseTypeName + } else + false + + override fun toString(): String { + return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}" + } + val responseType: Class + get() = Class.forName(responseTypeName) as Class + } class NotExpectingResponse( topic: String, diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt index 93149ba6f5..35ed28cb30 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt @@ -3,8 +3,10 @@ package com.r3corda.node.services.persistence import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Jimfs import com.google.common.primitives.Ints +import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.SerializedBytes import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.statemachine.StateMachineManager import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.After @@ -92,6 +94,8 @@ class PerFileCheckpointStorageTests { } private var checkpointCount = 1 - private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), "topic", "javaType", null) + private val request = StateMachineManager.FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null, + java.lang.String::class.java) + private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request) } \ No newline at end of file