// TODO: Clean this up
/**
 * Describes what a fiber asked for when it suspended: a message to send, a message to wait for, or both.
 *
 * Equality and hashing are value-based over [topic], [destination], [sessionIDForSend], [sessionIDForReceive]
 * and [payload]; [stackTraceInCaseOfProblems] is deliberately excluded from both.
 *
 * @property topic base topic of the conversation; the session ID is appended to form [receiveTopic].
 * @property destination party the [payload] is addressed to, or null when nothing is being sent.
 * @property sessionIDForSend session ID used when sending [payload].
 * @property sessionIDForReceive session ID used to build [receiveTopic]; [NotExpectingResponse] passes -1.
 * @property payload object to send, or null when there is nothing to send.
 */
sealed class FiberRequest(val topic: String,
                          val destination: Party?,
                          val sessionIDForSend: Long,
                          val sessionIDForReceive: Long,
                          val payload: Any?) {
    // This is used to identify where we suspended, in case of message mismatch errors and other things where we
    // don't have the original stack trace because it's in a suspended fiber.
    // NOTE(review): marked @Transient and nullable, presumably because the serialiser skips the field and leaves
    // it null on a restored request - confirm against the serialisation configuration.
    @Transient
    val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()

    /** Topic on which a response to this request is expected: "<topic>.<sessionIDForReceive>". */
    val receiveTopic: String
        get() = topic + "." + sessionIDForReceive

    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        // Compare runtime classes as well as fields. An `is FiberRequest` check alone lets a NotExpectingResponse
        // equal an ExpectingResponse carrying the same fields while the reverse comparison is false, which
        // violates the symmetry requirement of equals().
        if (other !is FiberRequest || javaClass != other.javaClass) return false
        return topic == other.topic
                && destination == other.destination
                && sessionIDForSend == other.sessionIDForSend
                && sessionIDForReceive == other.sessionIDForReceive
                && payload == other.payload
    }

    override fun hashCode(): Int {
        var hash = 1L

        hash = (hash * 31) + topic.hashCode()
        hash = (hash * 31) + (destination?.hashCode() ?: 0)
        // Mix in both session IDs, mirroring the fields compared by equals().
        hash = (hash * 31) + sessionIDForSend
        hash = (hash * 31) + sessionIDForReceive
        hash = (hash * 31) + (payload?.hashCode() ?: 0)

        return hash.toInt()
    }

    /**
     * A fiber which is expecting a message response.
     *
     * Only the name of the expected response class is stored; [responseType] looks the class up again by name
     * each time it is read.
     */
    class ExpectingResponse<R : Any>(
            topic: String,
            destination: Party?,
            sessionIDForSend: Long,
            sessionIDForReceive: Long,
            obj: Any?,
            type: Class<R>
    ) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
        private val responseTypeName: String = type.name

        override fun equals(other: Any?): Boolean
                = other is ExpectingResponse<*>
                && super.equals(other)
                && responseTypeName == other.responseTypeName

        // Kept in step with equals(), which also compares the response type name.
        override fun hashCode(): Int = 31 * super.hashCode() + responseTypeName.hashCode()

        override fun toString(): String {
            return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
        }

        /**
         * The class of the expected response, resolved from its stored name.
         *
         * @throws ClassNotFoundException if the named class cannot be loaded.
         */
        @Suppress("UNCHECKED_CAST")
        val responseType: Class<R>
            get() = Class.forName(responseTypeName) as Class<R>
    }

    /**
     * A fiber which is sending a message and not waiting for any reply; the receive session ID is fixed at -1.
     */
    class NotExpectingResponse(
            topic: String,
            destination: Party,
            sessionIDForSend: Long,
            obj: Any?
    ) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}
UntrustworthyData { - val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType) + val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType) return suspendAndExpectReceive(result) } @Suspendable override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) { - val result = StateMachineManager.FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload) + val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload) suspend(result) } @Suspendable - private fun suspend(with: StateMachineManager.FiberRequest) { + private fun suspend(with: FiberRequest) { parkAndSerialize { fiber, serializer -> try { suspendAction(with) diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index cd8d74da65..23df79646f 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -8,7 +8,6 @@ import com.esotericsoftware.kryo.Kryo import com.google.common.base.Throwables import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.abbreviate -import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.send import com.r3corda.core.protocols.ProtocolLogic @@ -123,28 +122,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val fiber = deserializeFiber(checkpoint.serialisedFiber) initFiber(fiber, { checkpoint }) - val topic = checkpoint.awaitingTopic - if (topic != null) { - val awaitingPayloadType = Class.forName(checkpoint.awaitingPayloadType) - fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on 
/**
 * Resumes a restored fiber that is not waiting for a message: either it never asked for one (for example a
 * brand new fiber) or what it needed has already arrived.
 *
 * The payload taken from [request] (null when no request is given) is logged in abbreviated form and handed
 * to the state machine, which is iterated on the executor.
 *
 * @param fiber the deserialised fiber to resume.
 * @param request the request stored in the checkpoint, or null if the checkpoint carried none.
 */
private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) {
    val restoredPayload: Any? = if (request == null) null else request.payload
    fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${restoredPayload.toString().abbreviate(50)}")

    // Unparks the deserialised fiber; any failure is logged together with the payload being delivered.
    val resumeFiber: (Any?) -> Unit = { delivered ->
        try {
            Fiber.unparkDeserialized(fiber, scheduler)
        } catch (e: Throwable) {
            logError(e, delivered, null, fiber)
        }
    }

    executor.executeASAP { iterateStateMachine(fiber, restoredPayload, resumeFiber) }
}
- checkpointOnExpectingResponse(psm, request) + when (request) { + is FiberRequest.ExpectingResponse<*> -> { + // Prepare a listener on the network that runs in the background thread when we receive a message. + checkpointOnExpectingResponse(psm, request) + } } // If a non-null payload to send was provided, send it now. request.payload?.let { @@ -277,9 +288,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService executor.checkOnThread() val topic = "${request.topic}.${request.sessionIDForReceive}" val serialisedFiber = serializeFiber(psm) - updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null) + updateCheckpoint(psm, serialisedFiber, request) psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" } - iterateOnResponse(psm, request.responseType, serialisedFiber, topic) { + iterateOnResponse(psm, request.responseType, serialisedFiber, request) { try { Fiber.unpark(psm, QUASAR_UNBLOCKER) } catch(e: Throwable) { @@ -288,11 +299,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService } } + /** + * Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is + * received. + */ private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, responseType: Class<*>, serialisedFiber: SerializedBytes>, - topic: String, + request: FiberRequest.ExpectingResponse<*>, resumeAction: (Any?) -> Unit) { + val topic = request.receiveTopic serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg -> // Assertion to ensure we don't execute on the wrong thread. 
executor.checkOnThread() @@ -305,38 +321,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService val payload = netMsg.data.deserialize() check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" } // Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload - updateCheckpoint(psm, serialisedFiber, null, null, payload) + updateCheckpoint(psm, serialisedFiber, request) psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic $topic (${payload.toString().abbreviate(50)})" } iterateStateMachine(psm, payload, resumeAction) } } - - // TODO: Clean this up - open class FiberRequest(val topic: String, - val destination: Party?, - val sessionIDForSend: Long, - val sessionIDForReceive: Long, - val payload: Any?) { - // This is used to identify where we suspended, in case of message mismatch errors and other things where we - // don't have the original stack trace because it's in a suspended fiber. - val stackTraceInCaseOfProblems = StackSnapshot() - - class ExpectingResponse( - topic: String, - destination: Party?, - sessionIDForSend: Long, - sessionIDForReceive: Long, - obj: Any?, - val responseType: Class - ) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) - - class NotExpectingResponse( - topic: String, - destination: Party, - sessionIDForSend: Long, - obj: Any? 
- ) : FiberRequest(topic, destination, sessionIDForSend, -1, obj) - } } class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt index 93149ba6f5..6af826b8ba 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt @@ -3,8 +3,11 @@ package com.r3corda.node.services.persistence import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Jimfs import com.google.common.primitives.Ints +import com.r3corda.core.random63BitValue import com.r3corda.core.serialization.SerializedBytes import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.statemachine.FiberRequest +import com.r3corda.node.services.statemachine.StateMachineManager import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.After @@ -92,6 +95,8 @@ class PerFileCheckpointStorageTests { } private var checkpointCount = 1 - private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), "topic", "javaType", null) + private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null, + java.lang.String::class.java) + private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request) } \ No newline at end of file