mirror of
https://github.com/corda/corda.git
synced 2025-02-07 11:30:22 +00:00
Refactored FiberRequest into cleaner ProtocolIORequest and fixed checkpoint regression
This commit is contained in:
parent
a2d7490902
commit
97e1a59770
@ -7,17 +7,20 @@ import com.r3corda.core.node.ServiceHub
|
|||||||
import com.r3corda.core.utilities.UntrustworthyData
|
import com.r3corda.core.utilities.UntrustworthyData
|
||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* The interface of [ProtocolStateMachineImpl] exposing methods and properties required by ProtocolLogic for compilation.
|
* The interface of [ProtocolStateMachineImpl] exposing methods and properties required by ProtocolLogic for compilation.
|
||||||
*/
|
*/
|
||||||
interface ProtocolStateMachine<R> {
|
interface ProtocolStateMachine<R> {
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <T : Any> sendAndReceive(topic: String, destination: Party, sessionIDForSend: Long, sessionIDForReceive: Long,
|
fun <T : Any> sendAndReceive(topic: String,
|
||||||
payload: Any, recvType: Class<T>): UntrustworthyData<T>
|
destination: Party,
|
||||||
|
sessionIDForSend: Long,
|
||||||
|
sessionIDForReceive: Long,
|
||||||
|
payload: Any,
|
||||||
|
receiveType: Class<T>): UntrustworthyData<T>
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T>
|
fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T>
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
fun send(topic: String, destination: Party, sessionID: Long, payload: Any)
|
fun send(topic: String, destination: Party, sessionID: Long, payload: Any)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package com.r3corda.node.services.api
|
package com.r3corda.node.services.api
|
||||||
|
|
||||||
import com.r3corda.core.serialization.SerializedBytes
|
import com.r3corda.core.serialization.SerializedBytes
|
||||||
import com.r3corda.node.services.statemachine.FiberRequest
|
import com.r3corda.node.services.statemachine.ProtocolIORequest
|
||||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -32,5 +32,6 @@ interface CheckpointStorage {
|
|||||||
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
||||||
data class Checkpoint(
|
data class Checkpoint(
|
||||||
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||||
val request: FiberRequest?
|
val request: ProtocolIORequest?,
|
||||||
|
val receivedPayload: Any?
|
||||||
)
|
)
|
@ -1,86 +0,0 @@
|
|||||||
package com.r3corda.node.services.statemachine
|
|
||||||
|
|
||||||
import com.r3corda.core.crypto.Party
|
|
||||||
import com.r3corda.core.messaging.TopicSession
|
|
||||||
|
|
||||||
// TODO: Clean this up
|
|
||||||
sealed class FiberRequest(val topic: String,
|
|
||||||
val destination: Party?,
|
|
||||||
val sessionIDForSend: Long,
|
|
||||||
val sessionIDForReceive: Long,
|
|
||||||
val payload: Any?) {
|
|
||||||
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
|
|
||||||
// don't have the original stack trace because it's in a suspended fiber.
|
|
||||||
@Transient
|
|
||||||
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
|
|
||||||
|
|
||||||
val receiveTopicSession: TopicSession
|
|
||||||
get() = TopicSession(topic, sessionIDForReceive)
|
|
||||||
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean
|
|
||||||
= if (other is FiberRequest) {
|
|
||||||
topic == other.topic
|
|
||||||
&& destination == other.destination
|
|
||||||
&& sessionIDForSend == other.sessionIDForSend
|
|
||||||
&& sessionIDForReceive == other.sessionIDForReceive
|
|
||||||
&& payload == other.payload
|
|
||||||
} else
|
|
||||||
false
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
var hash = 1L
|
|
||||||
|
|
||||||
hash = (hash * 31) + topic.hashCode()
|
|
||||||
hash = (hash * 31) + if (destination == null)
|
|
||||||
0
|
|
||||||
else
|
|
||||||
destination.hashCode()
|
|
||||||
hash = (hash * 31) + sessionIDForReceive
|
|
||||||
hash = (hash * 31) + sessionIDForReceive
|
|
||||||
hash = (hash * 31) + if (payload == null)
|
|
||||||
0
|
|
||||||
else
|
|
||||||
payload.hashCode()
|
|
||||||
|
|
||||||
return hash.toInt()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A fiber which is expecting a message response.
|
|
||||||
*/
|
|
||||||
class ExpectingResponse<R : Any>(
|
|
||||||
topic: String,
|
|
||||||
destination: Party?,
|
|
||||||
sessionIDForSend: Long,
|
|
||||||
sessionIDForReceive: Long,
|
|
||||||
obj: Any?,
|
|
||||||
type: Class<R>
|
|
||||||
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
|
|
||||||
private val responseTypeName: String = type.name
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean {
|
|
||||||
return if (other is ExpectingResponse<*>) {
|
|
||||||
super.equals(other) && responseTypeName == other.responseTypeName
|
|
||||||
} else
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
override fun toString(): String {
|
|
||||||
return "Expecting response via topic $receiveTopicSession of type $responseTypeName"
|
|
||||||
}
|
|
||||||
|
|
||||||
// We have to do an unchecked cast, but unless the serialized form is damaged, this was
|
|
||||||
// correct when the request was instantiated
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
|
||||||
val responseType: Class<R>
|
|
||||||
get() = Class.forName(responseTypeName) as Class<R>
|
|
||||||
}
|
|
||||||
|
|
||||||
class NotExpectingResponse(
|
|
||||||
topic: String,
|
|
||||||
destination: Party,
|
|
||||||
sessionIDForSend: Long,
|
|
||||||
obj: Any?
|
|
||||||
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
|
|
||||||
}
|
|
@ -0,0 +1,51 @@
|
|||||||
|
package com.r3corda.node.services.statemachine
|
||||||
|
|
||||||
|
import com.r3corda.core.crypto.Party
|
||||||
|
import com.r3corda.core.messaging.TopicSession
|
||||||
|
|
||||||
|
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
|
||||||
|
interface ProtocolIORequest {
|
||||||
|
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
|
||||||
|
// don't have the original stack trace because it's in a suspended fiber.
|
||||||
|
val stackTraceInCaseOfProblems: StackSnapshot
|
||||||
|
val topic: String
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SendRequest : ProtocolIORequest {
|
||||||
|
val destination: Party
|
||||||
|
val payload: Any
|
||||||
|
val sendSessionID: Long
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ReceiveRequest<T> : ProtocolIORequest {
|
||||||
|
val receiveType: Class<T>
|
||||||
|
val receiveSessionID: Long
|
||||||
|
val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
data class SendAndReceive<T>(override val topic: String,
|
||||||
|
override val destination: Party,
|
||||||
|
override val payload: Any,
|
||||||
|
override val sendSessionID: Long,
|
||||||
|
override val receiveType: Class<T>,
|
||||||
|
override val receiveSessionID: Long) : SendRequest, ReceiveRequest<T> {
|
||||||
|
@Transient
|
||||||
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
data class ReceiveOnly<T>(override val topic: String,
|
||||||
|
override val receiveType: Class<T>,
|
||||||
|
override val receiveSessionID: Long) : ReceiveRequest<T> {
|
||||||
|
@Transient
|
||||||
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
data class SendOnly(override val destination: Party,
|
||||||
|
override val topic: String,
|
||||||
|
override val payload: Any,
|
||||||
|
override val sendSessionID: Long) : SendRequest {
|
||||||
|
@Transient
|
||||||
|
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
@ -29,7 +29,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
|||||||
|
|
||||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||||
@Transient internal lateinit var suspendAction: (FiberRequest) -> Unit
|
@Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit
|
||||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||||
@Transient internal var receivedPayload: Any? = null
|
@Transient internal var receivedPayload: Any? = null
|
||||||
|
|
||||||
@ -77,43 +77,40 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable @Suppress("UNCHECKED_CAST")
|
@Suspendable
|
||||||
private fun <T : Any> suspendAndExpectReceive(with: FiberRequest): UntrustworthyData<T> {
|
private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> {
|
||||||
suspend(with)
|
suspend(receiveRequest)
|
||||||
check(receivedPayload != null) { "Expected to receive something" }
|
check(receivedPayload != null) { "Expected to receive something" }
|
||||||
val untrustworthy = UntrustworthyData(receivedPayload as T)
|
val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload))
|
||||||
receivedPayload = null
|
receivedPayload = null
|
||||||
return untrustworthy
|
return untrustworthy
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable @Suppress("UNCHECKED_CAST")
|
@Suspendable
|
||||||
override fun <T : Any> sendAndReceive(topic: String,
|
override fun <T : Any> sendAndReceive(topic: String,
|
||||||
destination: Party,
|
destination: Party,
|
||||||
sessionIDForSend: Long,
|
sessionIDForSend: Long,
|
||||||
sessionIDForReceive: Long,
|
sessionIDForReceive: Long,
|
||||||
payload: Any,
|
payload: Any,
|
||||||
recvType: Class<T>): UntrustworthyData<T> {
|
receiveType: Class<T>): UntrustworthyData<T> {
|
||||||
val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType)
|
return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, receiveType, sessionIDForReceive))
|
||||||
return suspendAndExpectReceive(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> {
|
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> {
|
||||||
val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType)
|
return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive))
|
||||||
return suspendAndExpectReceive(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
|
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
|
||||||
val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload)
|
suspend(SendOnly(destination, topic, payload, sessionID))
|
||||||
suspend(result)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private fun suspend(with: FiberRequest) {
|
private fun suspend(protocolIORequest: ProtocolIORequest) {
|
||||||
parkAndSerialize { fiber, serializer ->
|
parkAndSerialize { fiber, serializer ->
|
||||||
try {
|
try {
|
||||||
suspendAction(with)
|
suspendAction(protocolIORequest)
|
||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
// Do not throw exception again - Quasar completely bins it.
|
// Do not throw exception again - Quasar completely bins it.
|
||||||
logger.warn("Captured exception which was swallowed by Quasar", t)
|
logger.warn("Captured exception which was swallowed by Quasar", t)
|
||||||
|
@ -125,38 +125,25 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
||||||
initFiber(fiber, { checkpoint })
|
initFiber(fiber, { checkpoint })
|
||||||
|
|
||||||
when (checkpoint.request) {
|
if (checkpoint.request is ReceiveRequest<*>) {
|
||||||
is FiberRequest.ExpectingResponse<*> -> {
|
val topicSession = checkpoint.request.receiveTopicSession
|
||||||
val topic = checkpoint.request.receiveTopicSession
|
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
|
||||||
val awaitingPayloadType = checkpoint.request.responseType
|
iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) {
|
||||||
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic")
|
|
||||||
iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, checkpoint.request) {
|
|
||||||
try {
|
|
||||||
Fiber.unparkDeserialized(fiber, scheduler)
|
|
||||||
} catch (e: Throwable) {
|
|
||||||
logError(e, it, topic, fiber)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
is FiberRequest.NotExpectingResponse -> restoreNotExpectingResponse(fiber, checkpoint.request)
|
|
||||||
null -> restoreNotExpectingResponse(fiber)
|
|
||||||
else -> throw IllegalStateException("Unknown fiber request type " + checkpoint.request)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Restore a Fiber which was not expecting a response (either because it hasn't asked for one, such as a new
|
|
||||||
* Fiber, or because the detail it needed has arrived).
|
|
||||||
*/
|
|
||||||
private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) {
|
|
||||||
val payload = request?.payload
|
|
||||||
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${payload.toString().abbreviate(50)}")
|
|
||||||
executor.executeASAP {
|
|
||||||
iterateStateMachine(fiber, payload) {
|
|
||||||
try {
|
try {
|
||||||
Fiber.unparkDeserialized(fiber, scheduler)
|
Fiber.unparkDeserialized(fiber, scheduler)
|
||||||
} catch (e: Throwable) {
|
} catch (e: Throwable) {
|
||||||
logError(e, it, null, fiber)
|
logError(e, it, topicSession, fiber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
|
||||||
|
executor.executeASAP {
|
||||||
|
iterateStateMachine(fiber, checkpoint.receivedPayload) {
|
||||||
|
try {
|
||||||
|
Fiber.unparkDeserialized(fiber, scheduler)
|
||||||
|
} catch (e: Throwable) {
|
||||||
|
logError(e, it, null, fiber)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -220,7 +207,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||||
// Need to add before iterating in case of immediate completion
|
// Need to add before iterating in case of immediate completion
|
||||||
initFiber(fiber) {
|
initFiber(fiber) {
|
||||||
val checkpoint = Checkpoint(serializeFiber(fiber), null)
|
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
|
||||||
checkpointStorage.addCheckpoint(checkpoint)
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
checkpoint
|
checkpoint
|
||||||
}
|
}
|
||||||
@ -250,8 +237,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
|
|
||||||
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
|
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
|
||||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||||
request: FiberRequest) {
|
request: ProtocolIORequest?,
|
||||||
val newCheckpoint = Checkpoint(serialisedFiber, request)
|
receivedPayload: Any?) {
|
||||||
|
val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload)
|
||||||
val previousCheckpoint = stateMachines.put(psm, newCheckpoint)
|
val previousCheckpoint = stateMachines.put(psm, newCheckpoint)
|
||||||
if (previousCheckpoint != null) {
|
if (previousCheckpoint != null) {
|
||||||
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
||||||
@ -269,62 +257,62 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
resumeAction(receivedPayload)
|
resumeAction(receivedPayload)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) {
|
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) {
|
||||||
// We have a request to do something: send, receive, or send-and-receive.
|
// We have a request to do something: send, receive, or send-and-receive.
|
||||||
when (request) {
|
if (request is ReceiveRequest<*>) {
|
||||||
is FiberRequest.ExpectingResponse<*> -> {
|
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
||||||
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
prepareToReceiveForRequest(psm, request)
|
||||||
checkpointOnExpectingResponse(psm, request)
|
}
|
||||||
|
if (request is SendRequest) {
|
||||||
|
performSendRequest(psm, request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, request: ReceiveRequest<*>) {
|
||||||
|
executor.checkOnThread()
|
||||||
|
val queueID = request.receiveTopicSession
|
||||||
|
val serialisedFiber = serializeFiber(psm)
|
||||||
|
updateCheckpoint(psm, serialisedFiber, request, null)
|
||||||
|
psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" }
|
||||||
|
iterateOnResponse(psm, serialisedFiber, request) {
|
||||||
|
try {
|
||||||
|
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
||||||
|
} catch(e: Throwable) {
|
||||||
|
logError(e, it, queueID, psm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// If a non-null payload to send was provided, send it now.
|
}
|
||||||
val queueID = TopicSession(request.topic, request.sessionIDForSend)
|
|
||||||
request.payload?.let {
|
private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) {
|
||||||
psm.logger.trace { "Sending message of type ${it.javaClass.name} using queue $queueID to ${request.destination} (${it.toString().abbreviate(50)})" }
|
val topicSession = TopicSession(request.topic, request.sendSessionID)
|
||||||
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination!!.name)
|
val payload = request.payload
|
||||||
if (node == null) {
|
psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" }
|
||||||
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${it.javaClass.name} on $queueID (${it.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems)
|
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?:
|
||||||
}
|
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems)
|
||||||
serviceHub.networkService.send(queueID, it, node.address)
|
serviceHub.networkService.send(topicSession, payload, node.address)
|
||||||
}
|
|
||||||
if (request is FiberRequest.NotExpectingResponse) {
|
if (request is SendOnly) {
|
||||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||||
iterateStateMachine(psm, null) {
|
iterateStateMachine(psm, null) {
|
||||||
try {
|
try {
|
||||||
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
||||||
} catch(e: Throwable) {
|
} catch(e: Throwable) {
|
||||||
logError(e, request.payload, queueID, psm)
|
logError(e, request.payload, topicSession, psm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun checkpointOnExpectingResponse(psm: ProtocolStateMachineImpl<*>, request: FiberRequest.ExpectingResponse<*>) {
|
|
||||||
executor.checkOnThread()
|
|
||||||
val queueID = request.receiveTopicSession
|
|
||||||
val serialisedFiber = serializeFiber(psm)
|
|
||||||
updateCheckpoint(psm, serialisedFiber, request)
|
|
||||||
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on queue $queueID" }
|
|
||||||
iterateOnResponse(psm, request.responseType, serialisedFiber, request) {
|
|
||||||
try {
|
|
||||||
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
|
||||||
} catch(e: Throwable) {
|
|
||||||
logError(e, it, queueID, psm)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
|
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
|
||||||
* received.
|
* received.
|
||||||
*/
|
*/
|
||||||
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
|
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
|
||||||
responseType: Class<*>,
|
|
||||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||||
request: FiberRequest.ExpectingResponse<*>,
|
request: ReceiveRequest<*>,
|
||||||
resumeAction: (Any?) -> Unit) {
|
resumeAction: (Any?) -> Unit) {
|
||||||
val topic = request.receiveTopicSession
|
val topicSession = request.receiveTopicSession
|
||||||
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
|
serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg ->
|
||||||
// Assertion to ensure we don't execute on the wrong thread.
|
// Assertion to ensure we don't execute on the wrong thread.
|
||||||
executor.checkOnThread()
|
executor.checkOnThread()
|
||||||
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
|
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
|
||||||
@ -334,13 +322,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
// at the last moment when we do the downcast. However this would make protocol code harder to read and
|
// at the last moment when we do the downcast. However this would make protocol code harder to read and
|
||||||
// make it more difficult to migrate to a more explicit serialisation scheme later.
|
// make it more difficult to migrate to a more explicit serialisation scheme later.
|
||||||
val payload = netMsg.data.deserialize<Any>()
|
val payload = netMsg.data.deserialize<Any>()
|
||||||
check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" }
|
check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" }
|
||||||
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
|
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
|
||||||
updateCheckpoint(psm, serialisedFiber, request)
|
updateCheckpoint(psm, serialisedFiber, null, payload)
|
||||||
psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic ${request.topic}.${request.sessionIDForReceive} (${payload.toString().abbreviate(50)})" }
|
psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" }
|
||||||
iterateStateMachine(psm, payload, resumeAction)
|
iterateStateMachine(psm, payload, resumeAction)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
|
||||||
|
@ -3,10 +3,8 @@ package com.r3corda.node.services.persistence
|
|||||||
import com.google.common.jimfs.Configuration.unix
|
import com.google.common.jimfs.Configuration.unix
|
||||||
import com.google.common.jimfs.Jimfs
|
import com.google.common.jimfs.Jimfs
|
||||||
import com.google.common.primitives.Ints
|
import com.google.common.primitives.Ints
|
||||||
import com.r3corda.core.random63BitValue
|
|
||||||
import com.r3corda.core.serialization.SerializedBytes
|
import com.r3corda.core.serialization.SerializedBytes
|
||||||
import com.r3corda.node.services.api.Checkpoint
|
import com.r3corda.node.services.api.Checkpoint
|
||||||
import com.r3corda.node.services.statemachine.FiberRequest
|
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||||
import org.junit.After
|
import org.junit.After
|
||||||
@ -94,8 +92,6 @@ class PerFileCheckpointStorageTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private var checkpointCount = 1
|
private var checkpointCount = 1
|
||||||
private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
|
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null)
|
||||||
kotlin.String::class.java)
|
|
||||||
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)
|
|
||||||
|
|
||||||
}
|
}
|
@ -2,62 +2,76 @@ package com.r3corda.node.services.statemachine
|
|||||||
|
|
||||||
import co.paralleluniverse.fibers.Fiber
|
import co.paralleluniverse.fibers.Fiber
|
||||||
import co.paralleluniverse.fibers.Suspendable
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
|
import com.r3corda.core.crypto.Party
|
||||||
|
import com.r3corda.core.node.NodeInfo
|
||||||
import com.r3corda.core.protocols.ProtocolLogic
|
import com.r3corda.core.protocols.ProtocolLogic
|
||||||
import com.r3corda.node.services.MockServiceHubInternal
|
import com.r3corda.core.random63BitValue
|
||||||
import com.r3corda.node.services.api.Checkpoint
|
import com.r3corda.testing.node.MockNetwork
|
||||||
import com.r3corda.node.services.api.CheckpointStorage
|
import com.r3corda.testing.node.MockNetwork.MockNode
|
||||||
import com.r3corda.node.services.api.MessagingServiceInternal
|
|
||||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
|
||||||
import com.r3corda.node.utilities.AffinityExecutor
|
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
import org.junit.After
|
import org.junit.After
|
||||||
|
import org.junit.Before
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import java.util.*
|
|
||||||
|
|
||||||
class StateMachineManagerTests {
|
class StateMachineManagerTests {
|
||||||
|
|
||||||
val checkpointStorage = RecordingCheckpointStorage()
|
val net = MockNetwork()
|
||||||
val network = InMemoryMessagingNetwork(false).InMemoryMessaging(true, InMemoryMessagingNetwork.Handle(1, "mock"))
|
lateinit var node1: MockNode
|
||||||
val smm = createManager()
|
lateinit var node2: MockNode
|
||||||
|
|
||||||
|
@Before
|
||||||
|
fun start() {
|
||||||
|
val nodes = net.createTwoNodes()
|
||||||
|
node1 = nodes.first
|
||||||
|
node2 = nodes.second
|
||||||
|
net.runNetwork()
|
||||||
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
fun cleanUp() {
|
fun cleanUp() {
|
||||||
network.stop()
|
net.stopNodes()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `newly added protocol is preserved on restart`() {
|
fun `newly added protocol is preserved on restart`() {
|
||||||
smm.add("test", ProtocolWithoutCheckpoints())
|
node1.smm.add("test", ProtocolWithoutCheckpoints())
|
||||||
// Ensure we're restoring from the original add checkpoint
|
val restoredProtocol = node1.restartAndGetRestoredProtocol<ProtocolWithoutCheckpoints>()
|
||||||
assertThat(checkpointStorage.allCheckpoints).hasSize(1)
|
|
||||||
val restoredProtocol = createManager().run {
|
|
||||||
start()
|
|
||||||
findStateMachines(ProtocolWithoutCheckpoints::class.java).single().first
|
|
||||||
}
|
|
||||||
assertThat(restoredProtocol.protocolStarted).isTrue()
|
assertThat(restoredProtocol.protocolStarted).isTrue()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `protocol can lazily use the serviceHub in its constructor`() {
|
fun `protocol can lazily use the serviceHub in its constructor`() {
|
||||||
val protocol = ProtocolWithLazyServiceHub()
|
val protocol = ProtocolWithLazyServiceHub()
|
||||||
smm.add("test", protocol)
|
node1.smm.add("test", protocol)
|
||||||
assertThat(protocol.lazyTime).isNotNull()
|
assertThat(protocol.lazyTime).isNotNull()
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun createManager() = StateMachineManager(object : MockServiceHubInternal() {
|
@Test
|
||||||
override val networkService: MessagingServiceInternal get() = network
|
fun `protocol suspended just after receiving payload`() {
|
||||||
}, emptyList(), checkpointStorage, AffinityExecutor.SAME_THREAD)
|
val topic = "send-and-receive"
|
||||||
|
val sessionID = random63BitValue()
|
||||||
|
val payload = random63BitValue()
|
||||||
|
node1.smm.add("test", SendProtocol(topic, node2.info.identity, sessionID, payload))
|
||||||
|
node2.smm.add("test", ReceiveProtocol(topic, sessionID))
|
||||||
|
net.runNetwork()
|
||||||
|
node2.stop()
|
||||||
|
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info)
|
||||||
|
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: NodeInfo? = null): P {
|
||||||
|
val node = mockNet.createNode(networkMapAddress, id)
|
||||||
|
return node.smm.findStateMachines(P::class.java).single().first
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
|
||||||
private class ProtocolWithoutCheckpoints : ProtocolLogic<Unit>() {
|
|
||||||
|
|
||||||
@Transient var protocolStarted = false
|
@Transient var protocolStarted = false
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun call() {
|
override fun doCall() {
|
||||||
protocolStarted = true
|
protocolStarted = true
|
||||||
Fiber.park()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override val topic: String get() = throw UnsupportedOperationException()
|
override val topic: String get() = throw UnsupportedOperationException()
|
||||||
@ -75,21 +89,37 @@ class StateMachineManagerTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RecordingCheckpointStorage : CheckpointStorage {
|
private class SendProtocol(override val topic: String, val destination: Party, val sessionID: Long, val payload: Any) : ProtocolLogic<Unit>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call() = send(destination, sessionID, payload)
|
||||||
|
}
|
||||||
|
|
||||||
private val _checkpoints = ArrayList<Checkpoint>()
|
|
||||||
val allCheckpoints = ArrayList<Checkpoint>()
|
|
||||||
|
|
||||||
override fun addCheckpoint(checkpoint: Checkpoint) {
|
private class ReceiveProtocol(override val topic: String, val sessionID: Long) : NonTerminatingProtocol() {
|
||||||
_checkpoints.add(checkpoint)
|
|
||||||
allCheckpoints.add(checkpoint)
|
@Transient var receivedPayload: Any? = null
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
override fun doCall() {
|
||||||
|
receivedPayload = receive<Any>(sessionID).validate { it }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A protocol that suspends forever after doing some work. This is to allow it to be retrieved from the SMM after
|
||||||
|
* restart for testing checkpoint restoration. Store any results as @Transient fields.
|
||||||
|
*/
|
||||||
|
private abstract class NonTerminatingProtocol : ProtocolLogic<Unit>() {
|
||||||
|
|
||||||
|
@Suspendable
|
||||||
|
override fun call() {
|
||||||
|
doCall()
|
||||||
|
Fiber.park()
|
||||||
}
|
}
|
||||||
|
|
||||||
override fun removeCheckpoint(checkpoint: Checkpoint) {
|
@Suspendable
|
||||||
_checkpoints.remove(checkpoint)
|
abstract fun doCall()
|
||||||
}
|
|
||||||
|
|
||||||
override val checkpoints: Iterable<Checkpoint> get() = _checkpoints
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user